Commit 227bdaea authored by Philipp Arras's avatar Philipp Arras
Browse files

Add mean and covariance to gaussian_energy

parent a9a459c4
......@@ -17,45 +17,53 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from import Energy
from ..operators.inversion_enabler import InversionEnabler
from ..operators.sandwich_operator import SandwichOperator
from ..utilities import memo
class GaussianEnergy(Energy):
def __init__(self, s, inverter=None):
def __init__(self, inp, mean=None, covariance=None):
s: Sky model object
inp: Sky model object
value = 0.5 * s.vdot(s), i.e. a log-Gauss distribution with unit
super(GaussianEnergy, self).__init__(s.position)
self._s = s
self._inverter = inverter
super(GaussianEnergy, self).__init__(inp.position)
self._inp = inp
self._mean = mean
self._cov = covariance
def at(self, position):
return self.__class__(, self._inverter)
return self.__class__(, self._mean, self._cov)
def _gradient_helper(self):
return self._s.gradient
def residual(self):
if self._mean is not None:
return self._inp.value - self._mean
return self._inp.value
def value(self):
return .5 * self._s.value.squared_norm()
if self._cov is not None:
return .5 * self.residual.vdot(self._cov.inverse(self.residual))
return .5 * self.residual.vdot(self.residual).real
def _gradient_helper(self):
return self._inp.gradient
def gradient(self):
return self._gradient_helper.adjoint(self._s.value)
if self._cov is not None:
return self._gradient_helper.adjoint(self._cov.inverse(self.residual))
return self._gradient_helper.adjoint(self.residual)
def curvature(self):
c = SandwichOperator.make(self._gradient_helper)
if self._inverter is None:
return c
return InversionEnabler(c, self._inverter)
return SandwichOperator.make(self._gradient_helper)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment