Commit 939a0fd1 authored by Martin Reinecke's avatar Martin Reinecke

stricter energy tests

parent 6947fe75
Pipeline #31559 failed with stages
in 1 minute and 25 seconds
......@@ -44,7 +44,7 @@ def _get_acceptable_energy(E):
return E2
def check_value_gradient_consistency(E, tol=1e-6, ntries=100):
def check_value_gradient_consistency(E, tol=1e-10, ntries=100):
for _ in range(ntries):
E2 = _get_acceptable_energy(E)
val = E.value
......@@ -54,9 +54,8 @@ def check_value_gradient_consistency(E, tol=1e-6, ntries=100):
for i in range(50):
Emid = E.at(E.position + 0.5*dir)
dirder = Emid.gradient.vdot(dir)/dirnorm
t1 = (E2.value-val)/dirnorm
xtol = tol*max(abs(t1), abs(dirder))
if abs(t1-dirder) < xtol:
xtol = tol*Emid.gradient_norm
if abs((E2.value-val)/dirnorm - dirder) < xtol:
break
dir *= 0.5
dirnorm *= 0.5
......@@ -66,7 +65,7 @@ def check_value_gradient_consistency(E, tol=1e-6, ntries=100):
# E = Enext
def check_value_gradient_curvature_consistency(E, tol=1e-6, ntries=100):
def check_value_gradient_curvature_consistency(E, tol=1e-10, ntries=100):
for _ in range(ntries):
E2 = _get_acceptable_energy(E)
val = E.value
......@@ -77,8 +76,9 @@ def check_value_gradient_curvature_consistency(E, tol=1e-6, ntries=100):
Emid = E.at(E.position + 0.5*dir)
dirder = Emid.gradient.vdot(dir)/dirnorm
dgrad = Emid.curvature(dir)/dirnorm
if abs((E2.value-val)/dirnorm-dirder) < tol and \
(abs((E2.gradient-E.gradient)/dirnorm-dgrad) < tol).all():
xtol = tol*Emid.gradient_norm
if abs((E2.value-val)/dirnorm - dirder) < xtol and \
(abs((E2.gradient-E.gradient)/dirnorm-dgrad) < xtol).all():
break
dir *= 0.5
dirnorm *= 0.5
......
......@@ -61,7 +61,7 @@ class Energy_Tests(unittest.TestCase):
energy = ift.library.WienerFilterEnergy(
position=s0, d=d, R=R, N=N, S=S, iteration_controller=IC)
ift.extra.check_value_gradient_curvature_consistency(
energy, tol=1e-6, ntries=10)
energy, ntries=10)
@expand(product([ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
......@@ -95,7 +95,7 @@ class Energy_Tests(unittest.TestCase):
N=N, S=S)
if isinstance(nonlinearity, ift.library.Linear):
ift.extra.check_value_gradient_curvature_consistency(
energy, tol=1e-6, ntries=10)
energy, ntries=10)
else:
ift.extra.check_value_gradient_consistency(
energy, tol=1e-6, ntries=10)
energy, ntries=10)
......@@ -82,4 +82,4 @@ class Noise_Energy_Tests(unittest.TestCase):
for _ in range(10)]
energy = ift.library.NoiseEnergy(eta0, alpha, q, res_sample_list)
ift.extra.check_value_gradient_consistency(energy, tol=1e-6, ntries=10)
ift.extra.check_value_gradient_consistency(energy, ntries=10)
......@@ -81,4 +81,4 @@ class Energy_Tests(unittest.TestCase):
ht=ht,
N=N,
samples=10)
ift.extra.check_value_gradient_consistency(energy, tol=1e-6, ntries=10)
ift.extra.check_value_gradient_consistency(energy, ntries=10)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment