From c24d275b39fe8c5ddb77bc1c9f1be56b6f2ba7ab Mon Sep 17 00:00:00 2001 From: Philipp Arras <parras@mpa-garching.mpg.de> Date: Fri, 29 May 2020 12:46:44 +0200 Subject: [PATCH] Refactoring --- src/operators/energy_operators.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/operators/energy_operators.py b/src/operators/energy_operators.py index 827454d43..b7b242b4d 100644 --- a/src/operators/energy_operators.py +++ b/src/operators/energy_operators.py @@ -150,21 +150,21 @@ class VariableCovarianceGaussianEnergy(EnergyOperator): """ def __init__(self, domain, residual_key, inverse_covariance_key, sampling_dtype): - self._r = str(residual_key) - self._icov = str(inverse_covariance_key) + self._kr = str(residual_key) + self._ki = str(inverse_covariance_key) dom = DomainTuple.make(domain) - self._domain = MultiDomain.make({self._r: dom, self._icov: dom}) - self._sampling_dtype = sampling_dtype + self._domain = MultiDomain.make({self._kr: dom, self._ki: dom}) + self._dt = sampling_dtype _check_sampling_dtype(self._domain, sampling_dtype) def apply(self, x): self._check_input(x) - res = 0.5*(x[self._r].vdot(x[self._r]*x[self._icov].real).real - x[self._icov].ptw("log").sum()) + r, i = x[self._kr], x[self._ki] + res = 0.5*(r.vdot(r*i.real).real - i.ptw("log").sum()) if not x.want_metric: return res - mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)} - met = makeOp(MultiField.from_dict(mf)) - return res.add_metric(SamplingDtypeSetter(met, self._sampling_dtype)) + met = MultiField.from_dict({self._kr: i.val, self._ki: .5*i.val**(-2)}) + return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt)) class GaussianEnergy(EnergyOperator): -- GitLab