Commit 227bdaea authored by Philipp Arras's avatar Philipp Arras
Browse files

Add mean and covariance to gaussian_energy

parent a9a459c4
...@@ -17,45 +17,53 @@ ...@@ -17,45 +17,53 @@
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from ..minimization.energy import Energy from ..minimization.energy import Energy
from ..operators.inversion_enabler import InversionEnabler
from ..operators.sandwich_operator import SandwichOperator from ..operators.sandwich_operator import SandwichOperator
from ..utilities import memo from ..utilities import memo
class GaussianEnergy(Energy): class GaussianEnergy(Energy):
def __init__(self, s, inverter=None): def __init__(self, inp, mean=None, covariance=None):
""" """
s: Sky model object inp: Sky model object
value = 0.5 * s.vdot(s), i.e. a log-Gauss distribution with unit value = 0.5 * s.vdot(s), i.e. a log-Gauss distribution with unit
covariance covariance
""" """
super(GaussianEnergy, self).__init__(s.position) super(GaussianEnergy, self).__init__(inp.position)
self._s = s self._inp = inp
self._inverter = inverter self._mean = mean
self._cov = covariance
def at(self, position): def at(self, position):
return self.__class__(self._s.at(position), self._inverter) return self.__class__(self._inp.at(position), self._mean, self._cov)
@property @property
@memo @memo
def _gradient_helper(self): def residual(self):
return self._s.gradient if self._mean is not None:
return self._inp.value - self._mean
return self._inp.value
@property @property
@memo @memo
def value(self): def value(self):
return .5 * self._s.value.squared_norm() if self._cov is not None:
return .5 * self.residual.vdot(self._cov.inverse(self.residual))
return .5 * self.residual.vdot(self.residual).real
@property
@memo
def _gradient_helper(self):
return self._inp.gradient
@property @property
@memo @memo
def gradient(self): def gradient(self):
return self._gradient_helper.adjoint(self._s.value) if self._cov is not None:
return self._gradient_helper.adjoint(self._cov.inverse(self.residual))
return self._gradient_helper.adjoint(self.residual)
@property @property
@memo @memo
def curvature(self): def curvature(self):
c = SandwichOperator.make(self._gradient_helper) return SandwichOperator.make(self._gradient_helper)
if self._inverter is None:
return c
return InversionEnabler(c, self._inverter)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment