Commit 1299dd5e authored by Reimar H Leike's avatar Reimar H Leike
Browse files

Merge branch 'fixesinvcov' into 'NIFTy_7'

Fix dtype handling

See merge request !525
parents 682a34b7 bc9b978a
Pipeline #76028 passed with stages
in 23 minutes and 41 seconds
......@@ -154,16 +154,21 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
self._ki = str(inverse_covariance_key)
dom = DomainTuple.make(domain)
self._domain = MultiDomain.make({self._kr: dom, self._ki: dom})
self._dt = sampling_dtype
_check_sampling_dtype(self._domain, sampling_dtype)
self._dt = {self._kr: sampling_dtype, self._ki: np.float64}
_check_sampling_dtype(self._domain, self._dt)
self._cplx = np.issubdtype(sampling_dtype, np.complexfloating)
def apply(self, x):
self._check_input(x)
r, i = x[self._kr], x[self._ki]
res = 0.5*(r.vdot(r*i.real).real - i.ptw("log").sum())
if self._cplx:
res = 0.5*r.vdot(r*i.real).real - i.ptw("log").sum()
else:
res = 0.5*(r.vdot(r*i) - i.ptw("log").sum())
if not x.want_metric:
return res
met = MultiField.from_dict({self._kr: i.val, self._ki: .5*i.val**(-2)})
met = i.val if self._cplx else 0.5*i.val
met = MultiField.from_dict({self._kr: i.val, self._ki: met**(-2)})
return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment