Commit 13d16030 authored by Martin Reinecke's avatar Martin Reinecke

tweak GaussianEnergy

parent 9f070958
......@@ -50,6 +50,15 @@ class Linearization(object):
-self._val, -self._jac,
None if self._metric is None else -self._metric)
def conjugate(self):
return Linearization(
self._val.conjugate(), self._jac.conjugate(),
None if self._metric is None else self._metric.conjugate())
@property
def real(self):
return Linearization(self._val.real, self._jac.real)
def __add__(self, other):
if isinstance(other, Linearization):
from .operators.relaxed_sum_operator import RelaxedSumOperator
......
......@@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function
from ..compat import *
from ..minimization.energy import Energy
from ..linearization import Linearization
import numpy as np
class EnergyAdapter(Energy):
......@@ -16,7 +17,9 @@ class EnergyAdapter(Energy):
def _fill_all(self):
tmp = self._op(Linearization.make_var(self._position))
self._val = tmp.val.local_data[()]
self._val = tmp.val
if not np.isscalar(self._val):
self._val = self._val.local_data[()]
self._grad = tmp.gradient
self._metric = tmp.metric
......@@ -24,6 +27,8 @@ class EnergyAdapter(Energy):
def value(self):
if self._val is None:
self._val = self._op(self._position)
if not np.isscalar(self._val):
self._val = self._val.local_data[()]
return self._val
@property
......
......@@ -26,6 +26,8 @@ from .sampling_enabler import SamplingEnabler
from ..sugar import makeOp
from ..linearization import Linearization
from .. import utilities
from ..field import Field
from .simple_linear_operators import VdotOperator
class EnergyOperator(Operator):
......@@ -46,6 +48,10 @@ class SquaredNormOperator(EnergyOperator):
return self._domain
def apply(self, x):
if isinstance(x, Linearization):
val = Field(self._target, x.val.vdot(x.val))
jac = VdotOperator(2*x.val)(x.jac)
return Linearization(val, jac)
return Field(self._target, x.vdot(x))
......@@ -63,10 +69,11 @@ class QuadraticFormOperator(EnergyOperator):
def apply(self, x):
if isinstance(x, Linearization):
jac = self._op(x)
val = Field(self._target, 0.5 * x.vdot(jac))
t1 = self._op(x.val)
jac = VdotOperator(t1)(x.jac)
val = Field(self._target, 0.5*x.val.vdot(t1))
return Linearization(val, jac)
return Field(self._target, 0.5 * x.vdot(self._op(x)))
return Field(self._target, 0.5*x.vdot(self._op(x)))
class GaussianEnergy(EnergyOperator):
......@@ -82,6 +89,10 @@ class GaussianEnergy(EnergyOperator):
if self._domain is None:
raise ValueError("no domain given")
self._mean = mean
if covariance is None:
self._op = SquaredNormOperator(self._domain).scale(0.5)
else:
self._op = QuadraticFormOperator(covariance.inverse)
self._icov = None if covariance is None else covariance.inverse
def _checkEquivalence(self, newdom):
......@@ -97,8 +108,7 @@ class GaussianEnergy(EnergyOperator):
def apply(self, x):
residual = x if self._mean is None else x-self._mean
icovres = residual if self._icov is None else self._icov(residual)
res = .5*residual.vdot(icovres)
res = self._op(residual)
if not isinstance(x, Linearization):
return res
metric = SandwichOperator.make(x.jac, self._icov)
......
......@@ -34,6 +34,11 @@ class Operator(NiftyMetaBase()):
from .simple_linear_operators import ConjugationOperator
return ConjugationOperator(self.target)(self)
@property
def real(self):
from .simple_linear_operators import Realizer
return Realizer(self.target)(self)
def __neg__(self):
return self.scale(-1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment