Commit 939a0fd1 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

stricter energy tests

parent 6947fe75
Pipeline #31559 failed with stages
in 1 minute and 25 seconds
...@@ -44,7 +44,7 @@ def _get_acceptable_energy(E): ...@@ -44,7 +44,7 @@ def _get_acceptable_energy(E):
return E2 return E2
def check_value_gradient_consistency(E, tol=1e-6, ntries=100): def check_value_gradient_consistency(E, tol=1e-10, ntries=100):
for _ in range(ntries): for _ in range(ntries):
E2 = _get_acceptable_energy(E) E2 = _get_acceptable_energy(E)
val = E.value val = E.value
...@@ -54,9 +54,8 @@ def check_value_gradient_consistency(E, tol=1e-6, ntries=100): ...@@ -54,9 +54,8 @@ def check_value_gradient_consistency(E, tol=1e-6, ntries=100):
for i in range(50): for i in range(50):
Emid = E.at(E.position + 0.5*dir) Emid = E.at(E.position + 0.5*dir)
dirder = Emid.gradient.vdot(dir)/dirnorm dirder = Emid.gradient.vdot(dir)/dirnorm
t1 = (E2.value-val)/dirnorm xtol = tol*Emid.gradient_norm
xtol = tol*max(abs(t1), abs(dirder)) if abs((E2.value-val)/dirnorm - dirder) < xtol:
if abs(t1-dirder) < xtol:
break break
dir *= 0.5 dir *= 0.5
dirnorm *= 0.5 dirnorm *= 0.5
...@@ -66,7 +65,7 @@ def check_value_gradient_consistency(E, tol=1e-6, ntries=100): ...@@ -66,7 +65,7 @@ def check_value_gradient_consistency(E, tol=1e-6, ntries=100):
# E = Enext # E = Enext
def check_value_gradient_curvature_consistency(E, tol=1e-6, ntries=100): def check_value_gradient_curvature_consistency(E, tol=1e-10, ntries=100):
for _ in range(ntries): for _ in range(ntries):
E2 = _get_acceptable_energy(E) E2 = _get_acceptable_energy(E)
val = E.value val = E.value
...@@ -77,8 +76,9 @@ def check_value_gradient_curvature_consistency(E, tol=1e-6, ntries=100): ...@@ -77,8 +76,9 @@ def check_value_gradient_curvature_consistency(E, tol=1e-6, ntries=100):
Emid = E.at(E.position + 0.5*dir) Emid = E.at(E.position + 0.5*dir)
dirder = Emid.gradient.vdot(dir)/dirnorm dirder = Emid.gradient.vdot(dir)/dirnorm
dgrad = Emid.curvature(dir)/dirnorm dgrad = Emid.curvature(dir)/dirnorm
if abs((E2.value-val)/dirnorm-dirder) < tol and \ xtol = tol*Emid.gradient_norm
(abs((E2.gradient-E.gradient)/dirnorm-dgrad) < tol).all(): if abs((E2.value-val)/dirnorm - dirder) < xtol and \
(abs((E2.gradient-E.gradient)/dirnorm-dgrad) < xtol).all():
break break
dir *= 0.5 dir *= 0.5
dirnorm *= 0.5 dirnorm *= 0.5
......
...@@ -61,7 +61,7 @@ class Energy_Tests(unittest.TestCase): ...@@ -61,7 +61,7 @@ class Energy_Tests(unittest.TestCase):
energy = ift.library.WienerFilterEnergy( energy = ift.library.WienerFilterEnergy(
position=s0, d=d, R=R, N=N, S=S, iteration_controller=IC) position=s0, d=d, R=R, N=N, S=S, iteration_controller=IC)
ift.extra.check_value_gradient_curvature_consistency( ift.extra.check_value_gradient_curvature_consistency(
energy, tol=1e-6, ntries=10) energy, ntries=10)
@expand(product([ift.GLSpace(15), @expand(product([ift.GLSpace(15),
ift.RGSpace(64, distances=.789), ift.RGSpace(64, distances=.789),
...@@ -95,7 +95,7 @@ class Energy_Tests(unittest.TestCase): ...@@ -95,7 +95,7 @@ class Energy_Tests(unittest.TestCase):
N=N, S=S) N=N, S=S)
if isinstance(nonlinearity, ift.library.Linear): if isinstance(nonlinearity, ift.library.Linear):
ift.extra.check_value_gradient_curvature_consistency( ift.extra.check_value_gradient_curvature_consistency(
energy, tol=1e-6, ntries=10) energy, ntries=10)
else: else:
ift.extra.check_value_gradient_consistency( ift.extra.check_value_gradient_consistency(
energy, tol=1e-6, ntries=10) energy, ntries=10)
...@@ -82,4 +82,4 @@ class Noise_Energy_Tests(unittest.TestCase): ...@@ -82,4 +82,4 @@ class Noise_Energy_Tests(unittest.TestCase):
for _ in range(10)] for _ in range(10)]
energy = ift.library.NoiseEnergy(eta0, alpha, q, res_sample_list) energy = ift.library.NoiseEnergy(eta0, alpha, q, res_sample_list)
ift.extra.check_value_gradient_consistency(energy, tol=1e-6, ntries=10) ift.extra.check_value_gradient_consistency(energy, ntries=10)
...@@ -81,4 +81,4 @@ class Energy_Tests(unittest.TestCase): ...@@ -81,4 +81,4 @@ class Energy_Tests(unittest.TestCase):
ht=ht, ht=ht,
N=N, N=N,
samples=10) samples=10)
ift.extra.check_value_gradient_consistency(energy, tol=1e-6, ntries=10) ift.extra.check_value_gradient_consistency(energy, ntries=10)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment