### simplify Energy class

parent 92ff0e59
 ... ... @@ -66,7 +66,6 @@ from .minimization.scipy_minimizer import (ScipyMinimizer, NewtonCG, L_BFGS_B, from .minimization.energy import Energy from .minimization.quadratic_energy import QuadraticEnergy from .minimization.line_energy import LineEnergy from .minimization.energy_sum import EnergySum from .sugar import * from .plotting.plot import plot, plot_finish ... ... @@ -74,11 +73,11 @@ from .plotting.plot import plot, plot_finish from .library.amplitude_model import AmplitudeModel from .library.gaussian_energy import GaussianEnergy from .library.los_response import LOSResponse #from .library.point_sources import PointSources # from .library.point_sources import PointSources from .library.poissonian_energy import PoissonianEnergy from .library.wiener_filter_curvature import WienerFilterCurvature #from .library.correlated_fields import (make_correlated_field, # make_mf_correlated_field) # from .library.correlated_fields import (make_correlated_field, # make_mf_correlated_field) from .library.bernoulli_energy import BernoulliEnergy from . import extra ... ...
 ... ... @@ -135,31 +135,6 @@ class Energy(NiftyMetaBase()): raise TypeError return MetricInversionEnabler(self, controller, preconditioner) def __mul__(self, factor): from .energy_sum import EnergySum if isinstance(factor, (float, int)): return EnergySum.make([self], [factor]) return NotImplemented def __rmul__(self, factor): return self.__mul__(factor) def __add__(self, other): from .energy_sum import EnergySum if isinstance(other, Energy): return EnergySum.make([self, other]) return NotImplemented def __sub__(self, other): from .energy_sum import EnergySum if isinstance(other, Energy): return EnergySum.make([self, other], [1., -1.]) return NotImplemented def __neg__(self): from .energy_sum import EnergySum return EnergySum.make([self], [-1.]) class MetricInversionEnabler(Energy): def __init__(self, ene, controller, preconditioner): ... ...
 # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . # # Copyright(C) 2013-2018 Max-Planck-Society # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. from __future__ import absolute_import, division, print_function from ..compat import * from ..utilities import memo, my_lincomb, my_lincomb_simple from .energy import Energy class EnergySum(Energy): def __init__(self, position, energies, factors): super(EnergySum, self).__init__(position=position) self._energies = tuple(e.at(position) for e in energies) self._factors = tuple(factors) @staticmethod def make(energies, factors=None): if factors is None: factors = (1,)*len(energies) # unpack energies eout = [] fout = [] EnergySum._unpackEnergies(energies, factors, 1., eout, fout) # FIXME we need to reach an agreement about how we want to deal # with domain and field compatibility. Until then, we relax te check. # for e in eout[1:]: # if not e.position.isEquivalentTo(eout.position): # raise ValueError("position mismatch") return EnergySum(eout.position, eout, fout) @staticmethod def _unpackEnergies(e_in, f_in, prefactor, e_out, f_out): for e, f in zip(e_in, f_in): if isinstance(e, EnergySum): EnergySum._unpackEnergies(e._energies, e._factors, prefactor*f, e_out, f_out) else: e_out.append(e) f_out.append(prefactor*f) def at(self, position): return self.__class__(position, self._energies, self._factors) @property @memo def value(self): return my_lincomb_simple(map(lambda v: v.value, self._energies), self._factors) @property @memo def gradient(self): return my_lincomb(map(lambda v: v.gradient, self._energies), self._factors) @property @memo def metric(self): return my_lincomb(map(lambda v: v.metric, self._energies), self._factors)
 # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . # # Copyright(C) 2013-2018 Max-Planck-Society # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. import unittest from itertools import product from test.common import expand import nifty5 as ift import numpy as np from numpy.testing import assert_allclose, assert_raises class EnergySum_Tests(unittest.TestCase): @expand(product([ift.GLSpace(15), ift.RGSpace(64, distances=.789), ift.RGSpace([32, 32], distances=.789), [ift.RGSpace(3), ift.RGSpace(12)], {'a': ift.GLSpace(15), 'b': ift.RGSpace(64)}], [4, 78, 23])) def testSum(self, domain, seed): np.random.seed(seed) pos = ift.from_random("normal", domain) A1 = ift.makeOp(ift.from_random("normal", domain)) b1 = ift.from_random("normal", domain) E1 = ift.QuadraticEnergy(pos, A1, b1) A2 = ift.makeOp(ift.from_random("normal", domain)) b2 = ift.from_random("normal", domain) E2 = ift.QuadraticEnergy(pos, A2, b2) assert_allclose((E1+E2).value, E1.value+E2.value) assert_allclose((E1-E2).value, E1.value-E2.value) assert_allclose((3.8*E1+E2*4).value, 3.8*E1.value+4*E2.value) assert_allclose((-E1).value, -(E1.value)) with assert_raises(TypeError): E1*E2 with assert_raises(TypeError): E1*2j with assert_raises(TypeError): E1+2 with assert_raises(TypeError): E1-"hello" # with assert_raises(ValueError): # E1+E2.at(2*pos)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!