Commit 55112079 authored by Philipp Arras's avatar Philipp Arras

Addition for energies

parent 0ce57ca4
from .wiener_filter_energy import WienerFilterEnergy
from .wiener_filter_curvature import WienerFilterCurvature
from .los_response import LOSResponse
from .noise_energy import NoiseEnergy
from .nonlinear_power_energy import NonlinearPowerEnergy
from .nonlinear_wiener_filter_energy import NonlinearWienerFilterEnergy
from .nonlinearities import Exponential, Linear, PositiveTanh, Tanh
from .poisson_energy import PoissonEnergy
from .nonlinearities import Exponential, Linear, Tanh, PositiveTanh
from .los_response import LOSResponse
from .unit_log_gauss import UnitLogGauss
from .wiener_filter_curvature import WienerFilterCurvature
from .wiener_filter_energy import WienerFilterEnergy
......@@ -16,7 +16,9 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..utilities import memo, NiftyMetaBase
from ..field import Field
from ..multi import MultiField
from ..utilities import NiftyMetaBase, memo
class Energy(NiftyMetaBase()):
......@@ -125,3 +127,53 @@ class Energy(NiftyMetaBase()):
`dir`. If None, the step size is not limited.
"""
return None
def __add__(self, other):
assert isinstance(other, Energy)
return Add(self, other)
def __sub__(self, other):
assert isinstance(other, Energy)
return Add(self, (-1) * other)
class Add(Energy):
"""
Please not: If you add two operators which share some keys in the position
but have different values there, it is not guaranteed which value will be
used for the sum. You shouldn't do that anyways.
"""
def __init__(self, op1, op2):
if isinstance(op1.position, MultiField) and isinstance(op2.position, MultiField):
a = op1.position._val
b = op2.position._val
# Note: In python >3.5 one could do {**a, **b}
ab = a.copy()
ab.update(b)
position = MultiField(ab)
elif isinstance(op1.position, Field) and isinstance(op2.position, Field):
position = op1.position
else:
raise TypeError
super(Add, self).__init__(position)
self._op1 = op1.at(position)
self._op2 = op2.at(position)
def at(self, position):
return self.__class__(self._op1.at(position), self._op2.at(position))
@property
@memo
def value(self):
return self._op1.value + self._op2.value
@property
@memo
def gradient(self):
return self._op1.gradient + self._op2.gradient
@property
@memo
def curvature(self):
return self._op1.curvature + self._op2.curvature
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment