Commit bae4a32b authored by Martin Reinecke's avatar Martin Reinecke
Browse files

improve implementation of EnergySum and allow for linear combination of energies

parent ae802b86
......@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from ..field import Field
from ..multi import MultiField
from ..utilities import NiftyMetaBase, memo
......@@ -128,30 +129,18 @@ class Energy(NiftyMetaBase()):
"""
return None
def __add__(self, other):
if not isinstance(other, Energy):
raise TypeError
return Add(self, other)
def __mul__(self, factor):
from .energy_sum import EnergySum
if not isinstance(factor, (float, int)):
raise TypeError("Factor must be a real-valued scalar")
return EnergySum.make([self], [factor])
def __sub__(self, other):
def __rmul__(self, factor):
return self.__mul__(factor)
def __add__(self, other):
from .energy_sum import EnergySum
from ..sugar import full
if not isinstance(other, Energy):
raise TypeError
return Add(self, (-1) * other)
def Add(energy1, energy2):
if (isinstance(energy1.position, MultiField) and
isinstance(energy2.position, MultiField)):
a = energy1.position._val
b = energy2.position._val
# Note: In python >3.5 one could do {**a, **b}
ab = a.copy()
ab.update(b)
position = MultiField(ab)
elif (isinstance(energy1.position, Field) and
isinstance(energy2.position, Field)):
position = energy1.position
else:
raise TypeError
from .energy_sum import EnergySum
return EnergySum(position, [energy1, energy2])
raise TypeError("can only add Energies to Energies")
return EnergySum.make([self, other])
......@@ -21,44 +21,60 @@ from ..utilities import memo
class EnergySum(Energy):
def __init__(self, position, energies, minimizer_controller=None,
preconditioner=None, precon_idx=None):
def __init__(self, position, energies, factors):
super(EnergySum, self).__init__(position=position)
self._energies = [energy.at(position) for energy in energies]
self._min_controller = minimizer_controller
self._preconditioner = preconditioner
self._precon_idx = precon_idx
self._energies = tuple(e.at(position) for e in energies)
self._factors = tuple(factors)
@staticmethod
def make(energies, factors=None):
if factors is None:
factors = (1,)*len(energies)
# unpack energies
eout = []
fout = []
EnergySum._unpackEnergies(energies, factors, 1., eout, fout)
for e in eout[1:]:
if not e.position.isEquivalentTo(eout[0].position):
raise ValueError("position mismatch")
return EnergySum(eout[0].position, eout, fout)
@staticmethod
def _unpackEnergies(e_in, f_in, prefactor, e_out, f_out):
for e, f in zip(e_in, f_in):
if isinstance(e, EnergySum):
EnergySum._unpackEnergies(e._energies, e._factors, prefactor*f,
e_out, f_out)
else:
e_out.append(e)
f_out.append(prefactor*f)
def at(self, position):
return self.__class__(position, self._energies, self._min_controller,
self._preconditioner, self._precon_idx)
return self.__class__(position, self._energies, self._factors)
@property
@memo
def value(self):
res = self._energies[0].value
for e in self._energies[1:]:
res += e.value
res = self._energies[0].value * self._factors[0]
for e, f in zip(self._energies[1:], self._factors[1:]):
res += e.value * f
return res
@property
@memo
def gradient(self):
res = self._energies[0].gradient.copy()
for e in self._energies[1:]:
res += e.gradient
res = self._energies[0].gradient.copy() if self._factors[0] == 1. \
else self._energies[0].gradient * self._factors[0]
for e, f in zip(self._energies[1:], self._factors[1:]):
res += e.gradient if f == 1. else f*e.gradient
return res.lock()
@property
@memo
def curvature(self):
res = self._energies[0].curvature
for e in self._energies[1:]:
res = res + e.curvature
if self._min_controller is None:
return res
precon = self._preconditioner
if precon is None and self._precon_idx is not None:
precon = self._energies[self._precon_idx].curvature
from ..operators.inversion_enabler import InversionEnabler
return InversionEnabler(res, self._min_controller, precon)
res = self._energies[0].curvature if self._factors[0] == 1. \
else self._energies[0].curvature * self._factors[0]
for e, f in zip(self._energies[1:], self._factors[1:]):
res = res + (e.curvature if f == 1. else e.curvature*f)
return res
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment