diff --git a/nifty5/minimization/energy.py b/nifty5/minimization/energy.py index de48cc4ff1bdb9de0bf08ae08e497d3f6fd6abb1..224c6e7c2df67c9d55796af707c5ec5dd3f7102d 100644 --- a/nifty5/minimization/energy.py +++ b/nifty5/minimization/energy.py @@ -131,24 +131,24 @@ class Energy(NiftyMetaBase()): def __mul__(self, factor): from .energy_sum import EnergySum - if not isinstance(factor, (float, int)): - raise TypeError("Factor must be a real-valued scalar") - return EnergySum.make([self], [factor]) + if isinstance(factor, (float, int)): + return EnergySum.make([self], [factor]) + return NotImplemented def __rmul__(self, factor): return self.__mul__(factor) def __add__(self, other): from .energy_sum import EnergySum - if not isinstance(other, Energy): - raise TypeError("can only add Energies to Energies") - return EnergySum.make([self, other]) + if isinstance(other, Energy): + return EnergySum.make([self, other]) + return NotImplemented def __sub__(self, other): from .energy_sum import EnergySum - if not isinstance(other, Energy): - raise TypeError("can only subtract Energies from Energies") - return EnergySum.make([self, other], [1., -1.]) + if isinstance(other, Energy): + return EnergySum.make([self, other], [1., -1.]) + return NotImplemented def __neg__(self): from .energy_sum import EnergySum diff --git a/nifty5/minimization/energy_sum.py b/nifty5/minimization/energy_sum.py index f9886b4e85b30b8076d63e057e9dd5243769c429..5e288990edf45093beab1f20140fc5eb8ad31a95 100644 --- a/nifty5/minimization/energy_sum.py +++ b/nifty5/minimization/energy_sum.py @@ -16,8 +16,8 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from .energy import Energy from ..utilities import memo +from .energy import Energy class EnergySum(Energy): diff --git a/test/test_energies/test_energy_sum.py b/test/test_energies/test_energy_sum.py index 5d5fce30cab9bb2323a7cd2cc1724a229befdf1e..a63ce9540f2f1500032f10b5ded0e88b990b2c19 100644 --- a/test/test_energies/test_energy_sum.py +++ b/test/test_energies/test_energy_sum.py @@ -16,11 +16,12 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -import nifty5 as ift -import numpy as np import unittest from itertools import product from test.common import expand + +import nifty5 as ift +import numpy as np from numpy.testing import assert_allclose, assert_raises