Commit 13d16030 by Martin Reinecke

### tweak GaussianEnergy

parent 9f070958
 ... @@ -50,6 +50,15 @@ class Linearization(object): ... @@ -50,6 +50,15 @@ class Linearization(object): -self._val, -self._jac, -self._val, -self._jac, None if self._metric is None else -self._metric) None if self._metric is None else -self._metric) def conjugate(self): return Linearization( self._val.conjugate(), self._jac.conjugate(), None if self._metric is None else self._metric.conjugate()) @property def real(self): return Linearization(self._val.real, self._jac.real) def __add__(self, other): def __add__(self, other): if isinstance(other, Linearization): if isinstance(other, Linearization): from .operators.relaxed_sum_operator import RelaxedSumOperator from .operators.relaxed_sum_operator import RelaxedSumOperator ... ...
 ... @@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function ... @@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function from ..compat import * from ..compat import * from ..minimization.energy import Energy from ..minimization.energy import Energy from ..linearization import Linearization from ..linearization import Linearization import numpy as np class EnergyAdapter(Energy): class EnergyAdapter(Energy): ... @@ -16,7 +17,9 @@ class EnergyAdapter(Energy): ... @@ -16,7 +17,9 @@ class EnergyAdapter(Energy): def _fill_all(self): def _fill_all(self): tmp = self._op(Linearization.make_var(self._position)) tmp = self._op(Linearization.make_var(self._position)) self._val = tmp.val.local_data[()] self._val = tmp.val if not np.isscalar(self._val): self._val = self._val.local_data[()] self._grad = tmp.gradient self._grad = tmp.gradient self._metric = tmp.metric self._metric = tmp.metric ... @@ -24,6 +27,8 @@ class EnergyAdapter(Energy): ... @@ -24,6 +27,8 @@ class EnergyAdapter(Energy): def value(self): def value(self): if self._val is None: if self._val is None: self._val = self._op(self._position) self._val = self._op(self._position) if not np.isscalar(self._val): self._val = self._val.local_data[()] return self._val return self._val @property @property ... ...
 ... @@ -26,6 +26,8 @@ from .sampling_enabler import SamplingEnabler ... @@ -26,6 +26,8 @@ from .sampling_enabler import SamplingEnabler from ..sugar import makeOp from ..sugar import makeOp from ..linearization import Linearization from ..linearization import Linearization from .. import utilities from .. import utilities from ..field import Field from .simple_linear_operators import VdotOperator class EnergyOperator(Operator): class EnergyOperator(Operator): ... @@ -46,6 +48,10 @@ class SquaredNormOperator(EnergyOperator): ... @@ -46,6 +48,10 @@ class SquaredNormOperator(EnergyOperator): return self._domain return self._domain def apply(self, x): def apply(self, x): if isinstance(x, Linearization): val = Field(self._target, x.val.vdot(x.val)) jac = VdotOperator(2*x.val)(x.jac) return Linearization(val, jac) return Field(self._target, x.vdot(x)) return Field(self._target, x.vdot(x)) ... @@ -63,10 +69,11 @@ class QuadraticFormOperator(EnergyOperator): ... @@ -63,10 +69,11 @@ class QuadraticFormOperator(EnergyOperator): def apply(self, x): def apply(self, x): if isinstance(x, Linearization): if isinstance(x, Linearization): jac = self._op(x) t1 = self._op(x.val) val = Field(self._target, 0.5 * x.vdot(jac)) jac = VdotOperator(t1)(x.jac) val = Field(self._target, 0.5*x.val.vdot(t1)) return Linearization(val, jac) return Linearization(val, jac) return Field(self._target, 0.5 * x.vdot(self._op(x))) return Field(self._target, 0.5*x.vdot(self._op(x))) class GaussianEnergy(EnergyOperator): class GaussianEnergy(EnergyOperator): ... @@ -82,6 +89,10 @@ class GaussianEnergy(EnergyOperator): ... @@ -82,6 +89,10 @@ class GaussianEnergy(EnergyOperator): if self._domain is None: if self._domain is None: raise ValueError("no domain given") raise ValueError("no domain given") self._mean = mean self._mean = mean if covariance is None: self._op = SquaredNormOperator(self._domain).scale(0.5) else: self._op = QuadraticFormOperator(covariance.inverse) self._icov = None if covariance is None else covariance.inverse self._icov = None if covariance is None else covariance.inverse def _checkEquivalence(self, newdom): def _checkEquivalence(self, newdom): ... @@ -97,8 +108,7 @@ class GaussianEnergy(EnergyOperator): ... @@ -97,8 +108,7 @@ class GaussianEnergy(EnergyOperator): def apply(self, x): def apply(self, x): residual = x if self._mean is None else x-self._mean residual = x if self._mean is None else x-self._mean icovres = residual if self._icov is None else self._icov(residual) res = self._op(residual) res = .5*residual.vdot(icovres) if not isinstance(x, Linearization): if not isinstance(x, Linearization): return res return res metric = SandwichOperator.make(x.jac, self._icov) metric = SandwichOperator.make(x.jac, self._icov) ... ...
 ... @@ -34,6 +34,11 @@ class Operator(NiftyMetaBase()): ... @@ -34,6 +34,11 @@ class Operator(NiftyMetaBase()): from .simple_linear_operators import ConjugationOperator from .simple_linear_operators import ConjugationOperator return ConjugationOperator(self.target)(self) return ConjugationOperator(self.target)(self) @property def real(self): from .simple_linear_operators import Realizer return Realizer(self.target)(self) def __neg__(self): def __neg__(self): return self.scale(-1) return self.scale(-1) ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!