Commit be351767 authored by Philipp Arras's avatar Philipp Arras
Browse files

Merge branch 'tweak_energy_sum' into 'NIFTy_5'

improve implementation of EnergySum and allow for linear combination of energies

See merge request ift/nifty-dev!15
parents ae802b86 6b9d5cb4
......@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from ..field import Field
from ..multi import MultiField
from ..utilities import NiftyMetaBase, memo
......@@ -128,30 +129,27 @@ class Energy(NiftyMetaBase()):
"""
return None
def __mul__(self, factor):
from .energy_sum import EnergySum
if isinstance(factor, (float, int)):
return EnergySum.make([self], [factor])
return NotImplemented
def __rmul__(self, factor):
return self.__mul__(factor)
def __add__(self, other):
if not isinstance(other, Energy):
raise TypeError
return Add(self, other)
from .energy_sum import EnergySum
if isinstance(other, Energy):
return EnergySum.make([self, other])
return NotImplemented
def __sub__(self, other):
if not isinstance(other, Energy):
raise TypeError
return Add(self, (-1) * other)
def Add(energy1, energy2):
if (isinstance(energy1.position, MultiField) and
isinstance(energy2.position, MultiField)):
a = energy1.position._val
b = energy2.position._val
# Note: In python >3.5 one could do {**a, **b}
ab = a.copy()
ab.update(b)
position = MultiField(ab)
elif (isinstance(energy1.position, Field) and
isinstance(energy2.position, Field)):
position = energy1.position
else:
raise TypeError
from .energy_sum import EnergySum
return EnergySum(position, [energy1, energy2])
from .energy_sum import EnergySum
if isinstance(other, Energy):
return EnergySum.make([self, other], [1., -1.])
return NotImplemented
def __neg__(self):
from .energy_sum import EnergySum
return EnergySum.make([self], [-1.])
......@@ -16,49 +16,65 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .energy import Energy
from ..utilities import memo
from .energy import Energy
class EnergySum(Energy):
def __init__(self, position, energies, minimizer_controller=None,
preconditioner=None, precon_idx=None):
def __init__(self, position, energies, factors):
super(EnergySum, self).__init__(position=position)
self._energies = [energy.at(position) for energy in energies]
self._min_controller = minimizer_controller
self._preconditioner = preconditioner
self._precon_idx = precon_idx
self._energies = tuple(e.at(position) for e in energies)
self._factors = tuple(factors)
@staticmethod
def make(energies, factors=None):
if factors is None:
factors = (1,)*len(energies)
# unpack energies
eout = []
fout = []
EnergySum._unpackEnergies(energies, factors, 1., eout, fout)
for e in eout[1:]:
if not e.position.isEquivalentTo(eout[0].position):
raise ValueError("position mismatch")
return EnergySum(eout[0].position, eout, fout)
@staticmethod
def _unpackEnergies(e_in, f_in, prefactor, e_out, f_out):
for e, f in zip(e_in, f_in):
if isinstance(e, EnergySum):
EnergySum._unpackEnergies(e._energies, e._factors, prefactor*f,
e_out, f_out)
else:
e_out.append(e)
f_out.append(prefactor*f)
def at(self, position):
return self.__class__(position, self._energies, self._min_controller,
self._preconditioner, self._precon_idx)
return self.__class__(position, self._energies, self._factors)
@property
@memo
def value(self):
res = self._energies[0].value
for e in self._energies[1:]:
res += e.value
res = self._energies[0].value * self._factors[0]
for e, f in zip(self._energies[1:], self._factors[1:]):
res += e.value * f
return res
@property
@memo
def gradient(self):
res = self._energies[0].gradient.copy()
for e in self._energies[1:]:
res += e.gradient
res = self._energies[0].gradient.copy() if self._factors[0] == 1. \
else self._energies[0].gradient * self._factors[0]
for e, f in zip(self._energies[1:], self._factors[1:]):
res += e.gradient if f == 1. else f*e.gradient
return res.lock()
@property
@memo
def curvature(self):
res = self._energies[0].curvature
for e in self._energies[1:]:
res = res + e.curvature
if self._min_controller is None:
return res
precon = self._preconditioner
if precon is None and self._precon_idx is not None:
precon = self._energies[self._precon_idx].curvature
from ..operators.inversion_enabler import InversionEnabler
return InversionEnabler(res, self._min_controller, precon)
res = self._energies[0].curvature if self._factors[0] == 1. \
else self._energies[0].curvature * self._factors[0]
for e, f in zip(self._energies[1:], self._factors[1:]):
res = res + (e.curvature if f == 1. else e.curvature*f)
return res
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import unittest
from itertools import product
from test.common import expand
import nifty5 as ift
import numpy as np
from numpy.testing import assert_allclose, assert_raises
class EnergySum_Tests(unittest.TestCase):
@expand(product([ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789),
[ift.RGSpace(3), ift.RGSpace(12)],
{'a': ift.GLSpace(15), 'b': ift.RGSpace(64)}],
[4, 78, 23]))
def testSum(self, domain, seed):
np.random.seed(seed)
pos = ift.from_random("normal", domain)
A1 = ift.makeOp(ift.from_random("normal", domain))
b1 = ift.from_random("normal", domain)
E1 = ift.QuadraticEnergy(pos, A1, b1)
A2 = ift.makeOp(ift.from_random("normal", domain))
b2 = ift.from_random("normal", domain)
E2 = ift.QuadraticEnergy(pos, A2, b2)
assert_allclose((E1+E2).value, E1.value+E2.value)
assert_allclose((E1-E2).value, E1.value-E2.value)
assert_allclose((3.8*E1+E2*4).value, 3.8*E1.value+4*E2.value)
assert_allclose((-E1).value, -(E1.value))
with assert_raises(TypeError):
E1*E2
with assert_raises(TypeError):
E1*2j
with assert_raises(TypeError):
E1+2
with assert_raises(TypeError):
E1-"hello"
with assert_raises(ValueError):
E1+E2.at(2*pos)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment