Skip to content
Snippets Groups Projects

Fix dtype handling

Merged Philipp Arras requested to merge fixesinvcov into NIFTy_7
1 file
+ 9
4
Compare changes
  • Side-by-side
  • Inline
@@ -154,16 +154,21 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
@@ -154,16 +154,21 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
self._ki = str(inverse_covariance_key)
self._ki = str(inverse_covariance_key)
dom = DomainTuple.make(domain)
dom = DomainTuple.make(domain)
self._domain = MultiDomain.make({self._kr: dom, self._ki: dom})
self._domain = MultiDomain.make({self._kr: dom, self._ki: dom})
self._dt = sampling_dtype
self._dt = {self._kr: sampling_dtype, self._ki: np.float64}
_check_sampling_dtype(self._domain, sampling_dtype)
_check_sampling_dtype(self._domain, self._dt)
 
self._cplx = np.issubdtype(sampling_dtype, np.complexfloating)
def apply(self, x):
def apply(self, x):
self._check_input(x)
self._check_input(x)
r, i = x[self._kr], x[self._ki]
r, i = x[self._kr], x[self._ki]
res = 0.5*(r.vdot(r*i.real).real - i.ptw("log").sum())
if self._cplx:
 
res = 0.5*r.vdot(r*i.real).real - i.ptw("log").sum()
 
else:
 
res = 0.5*(r.vdot(r*i) - i.ptw("log").sum())
if not x.want_metric:
if not x.want_metric:
return res
return res
met = MultiField.from_dict({self._kr: i.val, self._ki: .5*i.val**(-2)})
met = i.val if self._cplx else 0.5*i.val
 
met = MultiField.from_dict({self._kr: i.val, self._ki: met**(-2)})
return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt))
return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt))
Loading