diff --git a/nifty5/extra/energy_tests.py b/nifty5/extra/energy_tests.py index 693d1ba40e8797be0f52364190f4e4725d139b28..d926d10a3571cde0e1a8a33e2bf97a338ab253ca 100644 --- a/nifty5/extra/energy_tests.py +++ b/nifty5/extra/energy_tests.py @@ -54,7 +54,9 @@ def check_value_gradient_consistency(E, tol=1e-6, ntries=100): for i in range(50): Emid = E.at(E.position + 0.5*dir) dirder = Emid.gradient.vdot(dir)/dirnorm - if abs((E2.value-val)/dirnorm-dirder) < tol: + t1 = (E2.value-val)/dirnorm + xtol = tol*max(abs(t1), abs(dirder)) + if abs(t1-dirder) < xtol: break dir *= 0.5 dirnorm *= 0.5 diff --git a/test/test_energies/test_map.py b/test/test_energies/test_map.py index 7c7e5f5208f150dfa8968460394abb7aa6f56cf4..c2a8939be6afde9dedb4099350ca246af5ee9907 100644 --- a/test/test_energies/test_map.py +++ b/test/test_energies/test_map.py @@ -61,7 +61,7 @@ class Energy_Tests(unittest.TestCase): energy = ift.library.WienerFilterEnergy( position=s0, d=d, R=R, N=N, S=S, iteration_controller=IC) ift.extra.check_value_gradient_curvature_consistency( - energy, tol=1e-4, ntries=10) + energy, tol=1e-6, ntries=10) @expand(product([ift.GLSpace(15), ift.RGSpace(64, distances=.789), @@ -95,7 +95,7 @@ class Energy_Tests(unittest.TestCase): N=N, S=S) if isinstance(nonlinearity, ift.library.Linear): ift.extra.check_value_gradient_curvature_consistency( - energy, tol=1e-4, ntries=10) + energy, tol=1e-6, ntries=10) else: ift.extra.check_value_gradient_consistency( - energy, tol=1e-4, ntries=10) + energy, tol=1e-6, ntries=10) diff --git a/test/test_energies/test_power.py b/test/test_energies/test_power.py index 99669117f8a7ab42c89a49d7d43cca6cd0a0bc4f..4430a2a96f60893a6687040de59e7f208c6fdd9f 100644 --- a/test/test_energies/test_power.py +++ b/test/test_energies/test_power.py @@ -81,4 +81,4 @@ class Energy_Tests(unittest.TestCase): ht=ht, N=N, samples=10) - ift.extra.check_value_gradient_consistency(energy, tol=1e-5, ntries=10) + ift.extra.check_value_gradient_consistency(energy, tol=1e-6, ntries=10)