Commit c8cfecf6 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'NIFTy_5' of gitlab.mpcdf.mpg.de:ift/NIFTy into NIFTy_5

parents 6c8d82e8 4ba18e98
Pipeline #31328 passed with stages
in 2 minutes and 6 seconds
......@@ -6,15 +6,15 @@ from ..sugar import log, makeOp
class PoissonLogLikelihood(Energy):
def __init__(self, position, lamb, d):
def __init__(self, lamb, d):
"""
s: Sky model object
value = 0.5 * s.vdot(s), i.e. a log-Gauss distribution with unit
covariance
"""
super(PoissonLogLikelihood, self).__init__(position)
self._lamb = lamb.at(position)
super(PoissonLogLikelihood, self).__init__(lamb.position)
self._lamb = lamb
self._d = d
lamb_val = self._lamb.value
......@@ -29,7 +29,7 @@ class PoissonLogLikelihood(Energy):
self._curvature = SandwichOperator.make(self._lamb.gradient, metric)
def at(self, position):
return self.__class__(position, self._lamb, self._d)
return self.__class__(self._lamb.at(position), self._d)
@property
def value(self):
......
......@@ -4,19 +4,19 @@ from ..utilities import memo
class UnitLogGauss(Energy):
def __init__(self, position, s, inverter=None):
def __init__(self, s, inverter=None):
"""
s: Sky model object
value = 0.5 * s.vdot(s), i.e. a log-Gauss distribution with unit
covariance
"""
super(UnitLogGauss, self).__init__(position)
self._s = s.at(position)
super(UnitLogGauss, self).__init__(s.position)
self._s = s
self._inverter = inverter
def at(self, position):
return self.__class__(position, self._s, self._inverter)
return self.__class__(self._s.at(position), self._inverter)
@property
@memo
......
......@@ -129,11 +129,13 @@ class Energy(NiftyMetaBase()):
return None
def __add__(self, other):
assert isinstance(other, Energy)
if not isinstance(other, Energy):
raise TypeError
return Add(self, other)
def __sub__(self, other):
assert isinstance(other, Energy)
if not isinstance(other, Energy):
raise TypeError
return Add(self, (-1) * other)
......
......@@ -48,7 +48,8 @@ class MultiSkyGradientOperator(LinearOperator):
# Needed if gradients == {}
if res is None:
res = full(self.target, 0.)
assert res.domain == self.target
if not res.domain == self.target:
raise TypeError
else:
grad_keys = self._gradients.keys()
res = {}
......@@ -58,5 +59,6 @@ class MultiSkyGradientOperator(LinearOperator):
else:
res[dd] = full(self.domain[dd], 0.)
res = MultiField(res)
assert res.domain == self.domain
if not res.domain == self.domain:
raise TypeError
return res
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment