diff --git a/nifty5/extra/energy_and_model_tests.py b/nifty5/extra/energy_and_model_tests.py index abb157037c7c9281be85be4a7482c898ab688ec7..4680e03b667defbbc47beb783034e87fc86edd23 100644 --- a/nifty5/extra/energy_and_model_tests.py +++ b/nifty5/extra/energy_and_model_tests.py @@ -18,7 +18,8 @@ import numpy as np from ..sugar import from_random -from .. import Energy, Model +from ..minimization.energy import Energy +from ..models.model import Model __all__ = ["check_value_gradient_consistency", "check_value_gradient_curvature_consistency"] @@ -74,6 +75,7 @@ def check_value_gradient_consistency(E, tol=1e-8, ntries=100): E2 = _get_acceptable_model(E) val = E.value dir = E2.position - E.position + Enext = E2 dirnorm = dir.norm() for i in range(50): Emid = E.at(E.position + 0.5*dir) @@ -95,6 +97,7 @@ def check_value_gradient_consistency(E, tol=1e-8, ntries=100): E2 = Emid else: raise ValueError("gradient and value seem inconsistent") + E = Enext def check_value_gradient_curvature_consistency(E, tol=1e-8, ntries=100): @@ -104,6 +107,7 @@ def check_value_gradient_curvature_consistency(E, tol=1e-8, ntries=100): E2 = _get_acceptable_energy(E) val = E.value dir = E2.position - E.position + Enext = E2 dirnorm = dir.norm() for i in range(50): Emid = E.at(E.position + 0.5*dir) @@ -118,3 +122,4 @@ def check_value_gradient_curvature_consistency(E, tol=1e-8, ntries=100): E2 = Emid else: raise ValueError("gradient, value and curvature seem inconsistent") + E = Enext