Skip to content
Snippets Groups Projects
Commit c24d275b authored by Philipp Arras's avatar Philipp Arras
Browse files

Refactoring

parent 6251760d
No related branches found
No related tags found
2 merge requests!535Nifty 7,!509Support complex data in `VariableCovarianceGaussianEnergy` and use simplify for constant input for KL and `EnergyAdapter`
Pipeline #75764 passed
......@@ -150,21 +150,21 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
"""
def __init__(self, domain, residual_key, inverse_covariance_key, sampling_dtype):
self._r = str(residual_key)
self._icov = str(inverse_covariance_key)
self._kr = str(residual_key)
self._ki = str(inverse_covariance_key)
dom = DomainTuple.make(domain)
self._domain = MultiDomain.make({self._r: dom, self._icov: dom})
self._sampling_dtype = sampling_dtype
self._domain = MultiDomain.make({self._kr: dom, self._ki: dom})
self._dt = sampling_dtype
_check_sampling_dtype(self._domain, sampling_dtype)
def apply(self, x):
self._check_input(x)
res = 0.5*(x[self._r].vdot(x[self._r]*x[self._icov].real).real - x[self._icov].ptw("log").sum())
r, i = x[self._kr], x[self._ki]
res = 0.5*(r.vdot(r*i.real).real - i.ptw("log").sum())
if not x.want_metric:
return res
mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)}
met = makeOp(MultiField.from_dict(mf))
return res.add_metric(SamplingDtypeSetter(met, self._sampling_dtype))
met = MultiField.from_dict({self._kr: i.val, self._ki: .5*i.val**(-2)})
return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt))
class GaussianEnergy(EnergyOperator):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment