From 54c659a9c92894415e35227a50c202ea1461c3fd Mon Sep 17 00:00:00 2001 From: Martin Reinecke Date: Tue, 29 May 2018 14:08:03 +0200 Subject: [PATCH] speed up energy tests; test more spaces --- nifty4/extra/energy_tests.py | 9 +++++++-- test/test_energies/test_map.py | 6 ++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/nifty4/extra/energy_tests.py b/nifty4/extra/energy_tests.py index 719f8a8e..a1308ddc 100644 --- a/nifty4/extra/energy_tests.py +++ b/nifty4/extra/energy_tests.py @@ -25,9 +25,12 @@ __all__ = ["check_value_gradient_consistency", def _get_acceptable_energy(E): - if not np.isfinite(E.value): + val = E.value + if not np.isfinite(val): raise ValueError dir = from_random("normal", E.position.domain) + dirder = E.gradient.vdot(dir) + dir *= np.abs(val)/np.abs(dirder)*1e-5 # find a step length that leads to a "reasonable" energy for i in range(50): try: @@ -45,12 +48,14 @@ def _get_acceptable_energy(E): def check_value_gradient_consistency(E, tol=1e-6, ntries=100): for _ in range(ntries): E2 = _get_acceptable_energy(E) + val = E.value dir = E2.position - E.position Enext = E2 dirnorm = dir.norm() dirder = E.gradient.vdot(dir)/dirnorm for i in range(50): - if abs((E2.value-E.value)/dirnorm-dirder) < tol: + if abs((E2.value-val)/dirnorm-dirder) < tol: + print i, dirnorm, dirder, E2.value-val break dir *= 0.5 dirnorm *= 0.5 diff --git a/test/test_energies/test_map.py b/test/test_energies/test_map.py index 311feb18..fb298f90 100644 --- a/test/test_energies/test_map.py +++ b/test/test_energies/test_map.py @@ -29,7 +29,8 @@ def _flat_PS(k): class Energy_Tests(unittest.TestCase): - @expand(product([ift.RGSpace(64, distances=.789), + @expand(product([ift.GLSpace(15), + ift.RGSpace(64, distances=.789), ift.RGSpace([32, 32], distances=.789)], [4, 78, 23])) def testLinearMap(self, space, seed): @@ -63,7 +64,8 @@ class Energy_Tests(unittest.TestCase): ift.extra.check_value_gradient_curvature_consistency( energy, tol=1e-4, ntries=10) - @expand(product([ift.RGSpace(64, distances=.789), + @expand(product([ift.GLSpace(15), + ift.RGSpace(64, distances=.789), ift.RGSpace([32, 32], distances=.789)], [ift.library.Tanh, ift.library.Exponential, ift.library.Linear], -- GitLab