Skip to content
Snippets Groups Projects
Commit a29abca8 authored by Philipp Arras's avatar Philipp Arras
Browse files

Cosmetics

parent d159522a
Branches
Tags
1 merge request!417Performance pa
Pipeline #70453 failed
...@@ -19,17 +19,17 @@ import numpy as np ...@@ -19,17 +19,17 @@ import numpy as np
from .. import utilities from .. import utilities
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..multi_domain import MultiDomain
from ..field import Field from ..field import Field
from ..multi_field import MultiField
from ..linearization import Linearization from ..linearization import Linearization
from ..sugar import makeDomain, makeOp, full from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..sugar import makeDomain, makeOp
from .linear_operator import LinearOperator from .linear_operator import LinearOperator
from .operator import Operator from .operator import Operator
from .sampling_enabler import SamplingEnabler from .sampling_enabler import SamplingEnabler
from .sandwich_operator import SandwichOperator from .sandwich_operator import SandwichOperator
from .scaling_operator import ScalingOperator from .scaling_operator import ScalingOperator
from .simple_linear_operators import VdotOperator, FieldAdapter from .simple_linear_operators import FieldAdapter, VdotOperator
class EnergyOperator(Operator): class EnergyOperator(Operator):
...@@ -130,14 +130,12 @@ class VariableCovarianceGaussianEnergy(EnergyOperator): ...@@ -130,14 +130,12 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
from .contraction_operator import ContractionOperator
lin = isinstance(x, Linearization) lin = isinstance(x, Linearization)
r = FieldAdapter(self._domain[self._r], self._r) r = FieldAdapter(self._domain[self._r], self._r)
icov = FieldAdapter(self._domain[self._icov], self._icov) icov = FieldAdapter(self._domain[self._icov], self._icov)
res0 = r.vdot(r*icov).real res0 = r.vdot(r*icov).real
res1 = icov.log().sum() res1 = icov.log().sum()
res = 0.5*(res0-res1) res = (res0-res1).scale(0.5)(x)
res = res(x)
if not lin: if not lin:
return Field.scalar(res) return Field.scalar(res)
if not x.want_metric: if not x.want_metric:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment