diff --git a/nifty5/extra/energy_tests.py b/nifty5/extra/energy_and_model_tests.py similarity index 64% rename from nifty5/extra/energy_tests.py rename to nifty5/extra/energy_and_model_tests.py index 735e2118837b8ece865259e410e522706ffe0f33..34f40b844dfa814488d88a6cf140c8a9ede651de 100644 --- a/nifty5/extra/energy_tests.py +++ b/nifty5/extra/energy_and_model_tests.py @@ -18,15 +18,38 @@ import numpy as np from ..sugar import from_random +from .operator_tests import _assert_allclose +from .. import Energy, Model __all__ = ["check_value_gradient_consistency", "check_value_gradient_curvature_consistency"] +def _get_acceptable_model(M): + # TODO: do for model + val = M.value + if not np.isfinite(val.sum()): + raise ValueError('Initial Model value must be finite') + dir = from_random("normal", M.position.domain) + dirder = M.gradient(dir) + dir *= val/(dirder).norm()*1e-5 + # find a step length that leads to a "reasonable" Model + for i in range(50): + try: + M2 = M.at(M.position+dir) + if np.isfinite(M2.value.sum()) and abs(M2.value.sum()) < 1e20: + break + except FloatingPointError: + pass + dir *= 0.5 + else: + raise ValueError("could not find a reasonable initial step") + return M2 + def _get_acceptable_energy(E): val = E.value if not np.isfinite(val): - raise ValueError + raise ValueError('Initial Energy must be finite') dir = from_random("normal", E.position.domain) dirder = E.gradient.vdot(dir) dir *= np.abs(val)/np.abs(dirder)*1e-5 @@ -46,17 +69,28 @@ def _get_acceptable_energy(E): def check_value_gradient_consistency(E, tol=1e-8, ntries=100): for _ in range(ntries): - E2 = _get_acceptable_energy(E) + if isinstance(E, Energy): + E2 = _get_acceptable_energy(E) + else: + E2 = _get_acceptable_model(E) val = E.value dir = E2.position - E.position # Enext = E2 dirnorm = dir.norm() for i in range(50): Emid = E.at(E.position + 0.5*dir) - dirder = Emid.gradient.vdot(dir)/dirnorm - xtol = tol*Emid.gradient_norm - if abs((E2.value-val)/dirnorm - dirder) < xtol: - break + if isinstance(E, Energy): + dirder = Emid.gradient.vdot(dir)/dirnorm + else: + dirder = Emid.gradient(dir)/dirnorm + if isinstance(E, Model): + #TODO: use relative error here instead + if ((E2.value-val)/dirnorm - dirder).norm()/dirder.norm()< tol: + break + else: + xtol = tol*Emid.gradient_norm + if abs((E2.value-val)/dirnorm - dirder) < xtol: + break dir *= 0.5 dirnorm *= 0.5 E2 = Emid @@ -66,6 +100,8 @@ def check_value_gradient_consistency(E, tol=1e-8, ntries=100): def check_value_gradient_curvature_consistency(E, tol=1e-8, ntries=100): + if isinstance(E, Model): + raise ValueError('Models have no curvature, thus it cannot be tested.') for _ in range(ntries): E2 = _get_acceptable_energy(E) val = E.value diff --git a/test/test_models/test_model_gradients.py b/test/test_models/test_model_gradients.py new file mode 100644 index 0000000000000000000000000000000000000000..e6d6fd32f14639054dde67f1ee9615bfefd7a323 --- /dev/null +++ b/test/test_models/test_model_gradients.py @@ -0,0 +1,40 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2018 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik +# and financially supported by the Studienstiftung des deutschen Volkes. + +import unittest +import nifty5 as ift +import numpy as np +from itertools import product +from test.common import expand +from numpy.testing import assert_allclose + + + +class Model_Tests(unittest.TestCase): + @expand(product([ift.GLSpace(15), + ift.RGSpace(64, distances=.789), + ift.RGSpace([32, 32], distances=.789)], + [4, 78, 23])) + def testMul(self, space, seed): + np.random.seed(seed) + #TODO: use random Multifields instead + s1 = ift.full(space, np.random.randn()) + s2 = ift.full(space, np.random.randn()) + s1_var = ift.Variable(ift.MultiField({'s1':s1}))['s1'] + s2_var = ift.Variable(ift.MultiField({'s2':s2}))['s2'] + ift.extra.check_value_gradient_consistency(s1_var*s2_var)