Commit 13d16030 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweak GaussianEnergy

parent 9f070958
...@@ -50,6 +50,15 @@ class Linearization(object): ...@@ -50,6 +50,15 @@ class Linearization(object):
-self._val, -self._jac, -self._val, -self._jac,
None if self._metric is None else -self._metric) None if self._metric is None else -self._metric)
def conjugate(self):
return Linearization(
self._val.conjugate(), self._jac.conjugate(),
None if self._metric is None else self._metric.conjugate())
@property
def real(self):
return Linearization(self._val.real, self._jac.real)
def __add__(self, other): def __add__(self, other):
if isinstance(other, Linearization): if isinstance(other, Linearization):
from .operators.relaxed_sum_operator import RelaxedSumOperator from .operators.relaxed_sum_operator import RelaxedSumOperator
......
...@@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function ...@@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function
from ..compat import * from ..compat import *
from ..minimization.energy import Energy from ..minimization.energy import Energy
from ..linearization import Linearization from ..linearization import Linearization
import numpy as np
class EnergyAdapter(Energy): class EnergyAdapter(Energy):
...@@ -16,7 +17,9 @@ class EnergyAdapter(Energy): ...@@ -16,7 +17,9 @@ class EnergyAdapter(Energy):
def _fill_all(self): def _fill_all(self):
tmp = self._op(Linearization.make_var(self._position)) tmp = self._op(Linearization.make_var(self._position))
self._val = tmp.val.local_data[()] self._val = tmp.val
if not np.isscalar(self._val):
self._val = self._val.local_data[()]
self._grad = tmp.gradient self._grad = tmp.gradient
self._metric = tmp.metric self._metric = tmp.metric
...@@ -24,6 +27,8 @@ class EnergyAdapter(Energy): ...@@ -24,6 +27,8 @@ class EnergyAdapter(Energy):
def value(self): def value(self):
if self._val is None: if self._val is None:
self._val = self._op(self._position) self._val = self._op(self._position)
if not np.isscalar(self._val):
self._val = self._val.local_data[()]
return self._val return self._val
@property @property
......
...@@ -26,6 +26,8 @@ from .sampling_enabler import SamplingEnabler ...@@ -26,6 +26,8 @@ from .sampling_enabler import SamplingEnabler
from ..sugar import makeOp from ..sugar import makeOp
from ..linearization import Linearization from ..linearization import Linearization
from .. import utilities from .. import utilities
from ..field import Field
from .simple_linear_operators import VdotOperator
class EnergyOperator(Operator): class EnergyOperator(Operator):
...@@ -46,6 +48,10 @@ class SquaredNormOperator(EnergyOperator): ...@@ -46,6 +48,10 @@ class SquaredNormOperator(EnergyOperator):
return self._domain return self._domain
def apply(self, x): def apply(self, x):
if isinstance(x, Linearization):
val = Field(self._target, x.val.vdot(x.val))
jac = VdotOperator(2*x.val)(x.jac)
return Linearization(val, jac)
return Field(self._target, x.vdot(x)) return Field(self._target, x.vdot(x))
...@@ -63,10 +69,11 @@ class QuadraticFormOperator(EnergyOperator): ...@@ -63,10 +69,11 @@ class QuadraticFormOperator(EnergyOperator):
def apply(self, x): def apply(self, x):
if isinstance(x, Linearization): if isinstance(x, Linearization):
jac = self._op(x) t1 = self._op(x.val)
val = Field(self._target, 0.5 * x.vdot(jac)) jac = VdotOperator(t1)(x.jac)
val = Field(self._target, 0.5*x.val.vdot(t1))
return Linearization(val, jac) return Linearization(val, jac)
return Field(self._target, 0.5 * x.vdot(self._op(x))) return Field(self._target, 0.5*x.vdot(self._op(x)))
class GaussianEnergy(EnergyOperator): class GaussianEnergy(EnergyOperator):
...@@ -82,6 +89,10 @@ class GaussianEnergy(EnergyOperator): ...@@ -82,6 +89,10 @@ class GaussianEnergy(EnergyOperator):
if self._domain is None: if self._domain is None:
raise ValueError("no domain given") raise ValueError("no domain given")
self._mean = mean self._mean = mean
if covariance is None:
self._op = SquaredNormOperator(self._domain).scale(0.5)
else:
self._op = QuadraticFormOperator(covariance.inverse)
self._icov = None if covariance is None else covariance.inverse self._icov = None if covariance is None else covariance.inverse
def _checkEquivalence(self, newdom): def _checkEquivalence(self, newdom):
...@@ -97,8 +108,7 @@ class GaussianEnergy(EnergyOperator): ...@@ -97,8 +108,7 @@ class GaussianEnergy(EnergyOperator):
def apply(self, x): def apply(self, x):
residual = x if self._mean is None else x-self._mean residual = x if self._mean is None else x-self._mean
icovres = residual if self._icov is None else self._icov(residual) res = self._op(residual)
res = .5*residual.vdot(icovres)
if not isinstance(x, Linearization): if not isinstance(x, Linearization):
return res return res
metric = SandwichOperator.make(x.jac, self._icov) metric = SandwichOperator.make(x.jac, self._icov)
......
...@@ -34,6 +34,11 @@ class Operator(NiftyMetaBase()): ...@@ -34,6 +34,11 @@ class Operator(NiftyMetaBase()):
from .simple_linear_operators import ConjugationOperator from .simple_linear_operators import ConjugationOperator
return ConjugationOperator(self.target)(self) return ConjugationOperator(self.target)(self)
@property
def real(self):
from .simple_linear_operators import Realizer
return Realizer(self.target)(self)
def __neg__(self): def __neg__(self):
return self.scale(-1) return self.scale(-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment