Commit 9ebbecf5 authored by Martin Reinecke's avatar Martin Reinecke

simplify Energy class

parent 92ff0e59
......@@ -66,7 +66,6 @@ from .minimization.scipy_minimizer import (ScipyMinimizer, NewtonCG, L_BFGS_B,
from .minimization.energy import Energy
from .minimization.quadratic_energy import QuadraticEnergy
from .minimization.line_energy import LineEnergy
from .minimization.energy_sum import EnergySum
from .sugar import *
from .plotting.plot import plot, plot_finish
......@@ -74,11 +73,11 @@ from .plotting.plot import plot, plot_finish
from .library.amplitude_model import AmplitudeModel
from .library.gaussian_energy import GaussianEnergy
from .library.los_response import LOSResponse
#from .library.point_sources import PointSources
# from .library.point_sources import PointSources
from .library.poissonian_energy import PoissonianEnergy
from .library.wiener_filter_curvature import WienerFilterCurvature
#from .library.correlated_fields import (make_correlated_field,
# make_mf_correlated_field)
# from .library.correlated_fields import (make_correlated_field,
# make_mf_correlated_field)
from .library.bernoulli_energy import BernoulliEnergy
from . import extra
......
......@@ -135,31 +135,6 @@ class Energy(NiftyMetaBase()):
raise TypeError
return MetricInversionEnabler(self, controller, preconditioner)
def __mul__(self, factor):
from .energy_sum import EnergySum
if isinstance(factor, (float, int)):
return EnergySum.make([self], [factor])
return NotImplemented
def __rmul__(self, factor):
return self.__mul__(factor)
def __add__(self, other):
from .energy_sum import EnergySum
if isinstance(other, Energy):
return EnergySum.make([self, other])
return NotImplemented
def __sub__(self, other):
from .energy_sum import EnergySum
if isinstance(other, Energy):
return EnergySum.make([self, other], [1., -1.])
return NotImplemented
def __neg__(self):
from .energy_sum import EnergySum
return EnergySum.make([self], [-1.])
class MetricInversionEnabler(Energy):
def __init__(self, ene, controller, preconditioner):
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..utilities import memo, my_lincomb, my_lincomb_simple
from .energy import Energy
class EnergySum(Energy):
def __init__(self, position, energies, factors):
super(EnergySum, self).__init__(position=position)
self._energies = tuple(e.at(position) for e in energies)
self._factors = tuple(factors)
@staticmethod
def make(energies, factors=None):
if factors is None:
factors = (1,)*len(energies)
# unpack energies
eout = []
fout = []
EnergySum._unpackEnergies(energies, factors, 1., eout, fout)
# FIXME we need to reach an agreement about how we want to deal
# with domain and field compatibility. Until then, we relax te check.
# for e in eout[1:]:
# if not e.position.isEquivalentTo(eout[0].position):
# raise ValueError("position mismatch")
return EnergySum(eout[0].position, eout, fout)
@staticmethod
def _unpackEnergies(e_in, f_in, prefactor, e_out, f_out):
for e, f in zip(e_in, f_in):
if isinstance(e, EnergySum):
EnergySum._unpackEnergies(e._energies, e._factors, prefactor*f,
e_out, f_out)
else:
e_out.append(e)
f_out.append(prefactor*f)
def at(self, position):
return self.__class__(position, self._energies, self._factors)
@property
@memo
def value(self):
return my_lincomb_simple(map(lambda v: v.value, self._energies),
self._factors)
@property
@memo
def gradient(self):
return my_lincomb(map(lambda v: v.gradient, self._energies),
self._factors)
@property
@memo
def metric(self):
return my_lincomb(map(lambda v: v.metric, self._energies),
self._factors)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import unittest
from itertools import product
from test.common import expand
import nifty5 as ift
import numpy as np
from numpy.testing import assert_allclose, assert_raises
class EnergySum_Tests(unittest.TestCase):
@expand(product([ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789),
[ift.RGSpace(3), ift.RGSpace(12)],
{'a': ift.GLSpace(15), 'b': ift.RGSpace(64)}],
[4, 78, 23]))
def testSum(self, domain, seed):
np.random.seed(seed)
pos = ift.from_random("normal", domain)
A1 = ift.makeOp(ift.from_random("normal", domain))
b1 = ift.from_random("normal", domain)
E1 = ift.QuadraticEnergy(pos, A1, b1)
A2 = ift.makeOp(ift.from_random("normal", domain))
b2 = ift.from_random("normal", domain)
E2 = ift.QuadraticEnergy(pos, A2, b2)
assert_allclose((E1+E2).value, E1.value+E2.value)
assert_allclose((E1-E2).value, E1.value-E2.value)
assert_allclose((3.8*E1+E2*4).value, 3.8*E1.value+4*E2.value)
assert_allclose((-E1).value, -(E1.value))
with assert_raises(TypeError):
E1*E2
with assert_raises(TypeError):
E1*2j
with assert_raises(TypeError):
E1+2
with assert_raises(TypeError):
E1-"hello"
# with assert_raises(ValueError):
# E1+E2.at(2*pos)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment