Commit 178ca953 authored by Philipp Arras's avatar Philipp Arras

Merge branch 'new_energies' into 'NIFTy_5'

New energies

See merge request ift/NIFTy!269
parents 0ce57ca4 b61562e0
Pipeline #31245 passed with stages
in 1 minute and 24 seconds
from .wiener_filter_energy import WienerFilterEnergy
from .wiener_filter_curvature import WienerFilterCurvature
from .los_response import LOSResponse
from .noise_energy import NoiseEnergy
from .nonlinear_power_energy import NonlinearPowerEnergy
from .nonlinear_wiener_filter_energy import NonlinearWienerFilterEnergy
from .nonlinearities import Exponential, Linear, PositiveTanh, Tanh
from .poisson_energy import PoissonEnergy
from .nonlinearities import Exponential, Linear, Tanh, PositiveTanh
from .los_response import LOSResponse
from .unit_log_gauss import UnitLogGauss
from .wiener_filter_curvature import WienerFilterCurvature
from .wiener_filter_energy import WienerFilterEnergy
from ..minimization import Energy
from ..operators import InversionEnabler, SandwichOperator
from ..utilities import memo
class UnitLogGauss(Energy):
def __init__(self, position, s, inverter=None):
"""
s: Sky model object
value = 0.5 * s.vdot(s), i.e. a log-Gauss distribution with unit
covariance
"""
super(UnitLogGauss, self).__init__(position)
self._s = s.at(position)
self._inverter = inverter
def at(self, position):
return self.__class__(position, self._s, self._inverter)
@property
@memo
def _gradient_helper(self):
return self._s.gradient
@property
@memo
def value(self):
return .5 * self._s.value.vdot(self._s.value)
@property
@memo
def gradient(self):
return self._gradient_helper.adjoint(self._s.value)
@property
@memo
def curvature(self):
c = SandwichOperator.make(self._gradient_helper)
if self._inverter is None:
return c
return InversionEnabler(c, self._inverter)
......@@ -16,7 +16,9 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..utilities import memo, NiftyMetaBase
from ..field import Field
from ..multi import MultiField
from ..utilities import NiftyMetaBase, memo
class Energy(NiftyMetaBase()):
......@@ -125,3 +127,27 @@ class Energy(NiftyMetaBase()):
`dir`. If None, the step size is not limited.
"""
return None
def __add__(self, other):
assert isinstance(other, Energy)
return Add(self, other)
def __sub__(self, other):
assert isinstance(other, Energy)
return Add(self, (-1) * other)
def Add(energy1, energy2):
if isinstance(energy1.position, MultiField) and isinstance(energy2.position, MultiField):
a = energy1.position._val
b = energy2.position._val
# Note: In python >3.5 one could do {**a, **b}
ab = a.copy()
ab.update(b)
position = MultiField(ab)
elif isinstance(energy1.position, Field) and isinstance(energy2.position, Field):
position = energy1.position
else:
raise TypeError
from .energy_sum import EnergySum
return EnergySum(position, [energy1, energy2])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment