diff --git a/src/operators/energy_operators.py b/src/operators/energy_operators.py index 827454d430bd7fbd9dc6ba83516c4acbc428adf4..b7b242b4db6a9fd34e7a5e47a9f75919365c4635 100644 --- a/src/operators/energy_operators.py +++ b/src/operators/energy_operators.py @@ -150,21 +150,21 @@ class VariableCovarianceGaussianEnergy(EnergyOperator): """ def __init__(self, domain, residual_key, inverse_covariance_key, sampling_dtype): - self._r = str(residual_key) - self._icov = str(inverse_covariance_key) + self._kr = str(residual_key) + self._ki = str(inverse_covariance_key) dom = DomainTuple.make(domain) - self._domain = MultiDomain.make({self._r: dom, self._icov: dom}) - self._sampling_dtype = sampling_dtype + self._domain = MultiDomain.make({self._kr: dom, self._ki: dom}) + self._dt = sampling_dtype _check_sampling_dtype(self._domain, sampling_dtype) def apply(self, x): self._check_input(x) - res = 0.5*(x[self._r].vdot(x[self._r]*x[self._icov].real).real - x[self._icov].ptw("log").sum()) + r, i = x[self._kr], x[self._ki] + res = 0.5*(r.vdot(r*i.real).real - i.ptw("log").sum()) if not x.want_metric: return res - mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)} - met = makeOp(MultiField.from_dict(mf)) - return res.add_metric(SamplingDtypeSetter(met, self._sampling_dtype)) + met = MultiField.from_dict({self._kr: i.val, self._ki: .5*i.val**(-2)}) + return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt)) class GaussianEnergy(EnergyOperator):