Commit 17d86163 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweak energy tests

parent e3bcf635
Pipeline #26739 passed with stage
in 11 minutes and 51 seconds
from .operator_tests import consistency_check from .operator_tests import consistency_check
from .energy_tests import check_value_gradient_consistency from .energy_tests import *
...@@ -19,35 +19,63 @@ ...@@ -19,35 +19,63 @@
import numpy as np import numpy as np
from ..field import Field from ..field import Field
__all__ = ["check_value_gradient_consistency"] __all__ = ["check_value_gradient_consistency",
"check_value_gradient_curvature_consistency"]
def check_value_gradient_consistency(E, tol=1e-6, ntries=100): def _get_acceptable_energy(E):
if not np.isfinite(E.value): if not np.isfinite(E.value):
raise ValueError raise ValueError
dir = Field.from_random("normal", E.position.domain)
# find a step length that leads to a "reasonable" energy
for i in range(50):
try:
E2 = E.at(E.position+dir)
if np.isfinite(E2.value) and abs(E2.value) < 1e20:
break
except FloatingPointError:
pass
dir *= 0.5
else:
raise ValueError("could not find a reasonable initial step")
return E2
def check_value_gradient_consistency(E, tol=1e-6, ntries=100):
for _ in range(ntries): for _ in range(ntries):
dir = Field.from_random("normal", E.position.domain) E2 = _get_acceptable_energy(E)
# find a step length that leads to a "reasonable" energy dir = E2.position - E.position
Enext = E2
dirnorm = dir.norm()
dirder = E.gradient.vdot(dir)/dirnorm
for i in range(50): for i in range(50):
try: print(abs((E2.value-E.value)/dirnorm-dirder))
E2 = E.at(E.position+dir) if abs((E2.value-E.value)/dirnorm-dirder) < tol:
if np.isfinite(E2.value) and abs(E2.value) < 1e20: break
break
except FloatingPointError:
pass
dir *= 0.5 dir *= 0.5
dirnorm *= 0.5
E2 = E2.at(E.position+dir)
else: else:
raise ValueError("could not find a reasonable initial step") raise ValueError("gradient and value seem inconsistent")
# E = Enext
def check_value_gradient_curvature_consistency(E, tol=1e-6, ntries=100):
for _ in range(ntries):
E2 = _get_acceptable_energy(E)
dir = E2.position - E.position
Enext = E2 Enext = E2
dirder = E.gradient.vdot(dir) dirnorm = dir.norm()
dirder = E.gradient.vdot(dir)/dirnorm
dgrad = E.curvature(dir)/dirnorm
for i in range(50): for i in range(50):
Ediff = E2.value - E.value gdiff = E2.gradient - E.gradient
eps = 1e-10*max(abs(E.value), abs(E2.value)) if abs((E2.value-E.value)/dirnorm-dirder) < tol and \
if abs(Ediff-dirder) < max([tol*abs(Ediff), tol*abs(dirder), eps]): (abs((E2.gradient-E.gradient)/dirnorm-dgrad) < tol).all():
break break
dir *= 0.5 dir *= 0.5
dirder *= 0.5 dirnorm *= 0.5
E2 = E2.at(E.position+dir) E2 = E2.at(E.position+dir)
else: else:
raise ValueError("gradient and value seem inconsistent") raise ValueError("gradient, value and curvature seem inconsistent")
E = Enext # E = Enext
...@@ -52,11 +52,6 @@ class Energy_Tests(unittest.TestCase): ...@@ -52,11 +52,6 @@ class Energy_Tests(unittest.TestCase):
N = ift.ScalingOperator(1., space) N = ift.ScalingOperator(1., space)
d = R(s0) + n d = R(s0) + n
direction = ift.Field.from_random('normal', hspace)
direction /= np.sqrt(direction.var())
eps = 1e-7
s1 = s0 + eps * direction
IC = ift.GradientNormController( IC = ift.GradientNormController(
iteration_limit=100, iteration_limit=100,
tol_abs_gradnorm=1e-5) tol_abs_gradnorm=1e-5)
...@@ -65,7 +60,8 @@ class Energy_Tests(unittest.TestCase): ...@@ -65,7 +60,8 @@ class Energy_Tests(unittest.TestCase):
S = ift.create_power_operator(hspace, power_spectrum=_flat_PS) S = ift.create_power_operator(hspace, power_spectrum=_flat_PS)
energy = ift.library.WienerFilterEnergy( energy = ift.library.WienerFilterEnergy(
position=s0, d=d, R=R, N=N, S=S, inverter=inverter) position=s0, d=d, R=R, N=N, S=S, inverter=inverter)
ift.extra.check_value_gradient_consistency(energy, tol=1e-8, ntries=10) ift.extra.check_value_gradient_curvature_consistency(
energy, tol=1e-4, ntries=10)
@expand(product([ift.RGSpace(64, distances=.789), @expand(product([ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)], ift.RGSpace([32, 32], distances=.789)],
...@@ -92,120 +88,13 @@ class Energy_Tests(unittest.TestCase): ...@@ -92,120 +88,13 @@ class Energy_Tests(unittest.TestCase):
N = ift.ScalingOperator(1., space) N = ift.ScalingOperator(1., space)
d = R(f(s)) + n d = R(f(s)) + n
direction = ift.Field.from_random('normal', hspace)
direction /= np.sqrt(direction.var())
eps = 1e-7
xi1 = xi0 + eps * direction
S = ift.create_power_operator(hspace, power_spectrum=_flat_PS) S = ift.create_power_operator(hspace, power_spectrum=_flat_PS)
energy = ift.library.NonlinearWienerFilterEnergy( energy = ift.library.NonlinearWienerFilterEnergy(
position=xi0, d=d, Instrument=R, nonlinearity=f, ht=ht, power=A, position=xi0, d=d, Instrument=R, nonlinearity=f, ht=ht, power=A,
N=N, S=S) N=N, S=S)
ift.extra.check_value_gradient_consistency(energy, tol=1e-8, ntries=10) if isinstance(nonlinearity, ift.library.Linear):
ift.extra.check_value_gradient_curvature_consistency(
energy, tol=1e-4, ntries=10)
class Curvature_Tests(unittest.TestCase): else:
# Note: It is only possible to test linear curvatures since the non-linear ift.extra.check_value_gradient_consistency(
# curvatures are not the exact second derivative but only a part of it. One energy, tol=1e-4, ntries=10)
# term is neglected which would render the second derivative non-positive
# definite.
@expand(product([ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]))
def testLinearMapCurvature(self, space, seed):
np.random.seed(seed)
dim = len(space.shape)
hspace = space.get_default_codomain()
ht = ift.HarmonicTransformOperator(hspace, target=space)
binbounds = ift.PowerSpace.useful_binbounds(hspace, logarithmic=False)
pspace = ift.PowerSpace(hspace, binbounds=binbounds)
Dist = ift.PowerDistributor(target=hspace, power_space=pspace)
xi0 = ift.Field.from_random(domain=hspace, random_type='normal')
def pspec(k): return 1 / (1 + k**2)**dim
pspec = ift.PS_field(pspace, pspec)
A = Dist(ift.sqrt(pspec))
n = ift.Field.from_random(domain=space, random_type='normal')
s0 = xi0 * A
Instrument = ift.ScalingOperator(10., space)
R = Instrument * ht
N = ift.ScalingOperator(1., space)
d = R(s0) + n
direction = ift.Field.from_random('normal', hspace)
direction /= np.sqrt(direction.var())
eps = 1e-7
s1 = s0 + eps * direction
IC = ift.GradientNormController(
iteration_limit=100,
tol_abs_gradnorm=1e-5)
inverter = ift.ConjugateGradient(IC)
S = ift.create_power_operator(hspace, power_spectrum=_flat_PS)
energy0 = ift.library.WienerFilterEnergy(
position=s0, d=d, R=R, N=N, S=S, inverter=inverter)
gradient0 = energy0.gradient
gradient1 = energy0.at(s1).gradient
a = (gradient1 - gradient0) / eps
b = energy0.curvature(direction)
tol = 1e-7
assert_allclose(a.to_global_data(), b.to_global_data(),
rtol=tol, atol=tol)
@expand(product([ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
# Only linear case due to approximation of Hessian in the
# case of nontrivial nonlinearities.
[ift.library.Linear],
[4, 78, 23]))
def testNonlinearMapCurvature(self, space, nonlinearity, seed):
np.random.seed(seed)
f = nonlinearity()
dim = len(space.shape)
hspace = space.get_default_codomain()
ht = ift.HarmonicTransformOperator(hspace, target=space)
binbounds = ift.PowerSpace.useful_binbounds(hspace, logarithmic=False)
pspace = ift.PowerSpace(hspace, binbounds=binbounds)
Dist = ift.PowerDistributor(target=hspace, power_space=pspace)
xi0 = ift.Field.from_random(domain=hspace, random_type='normal')
def pspec(k): return 1 / (1 + k**2)**dim
pspec = ift.PS_field(pspace, pspec)
A = Dist(ift.sqrt(pspec))
n = ift.Field.from_random(domain=space, random_type='normal')
s = ht(xi0 * A)
R = ift.ScalingOperator(10., space)
N = ift.ScalingOperator(1., space)
d = R(f(s)) + n
direction = ift.Field.from_random('normal', hspace)
direction /= np.sqrt(direction.var())
eps = 1e-7
xi1 = xi0 + eps * direction
S = ift.create_power_operator(hspace, power_spectrum=_flat_PS)
IC = ift.GradientNormController(
iteration_limit=500,
tol_abs_gradnorm=1e-7)
inverter = ift.ConjugateGradient(IC)
energy0 = ift.library.NonlinearWienerFilterEnergy(
position=xi0,
d=d,
Instrument=R,
nonlinearity=f,
ht=ht,
power=A,
N=N,
S=S,
inverter=inverter)
gradient0 = energy0.gradient
gradient1 = energy0.at(xi1).gradient
a = (gradient1 - gradient0) / eps
b = energy0.curvature(direction)
tol = 1e-7
assert_allclose(a.to_global_data(), b.to_global_data(),
rtol=tol, atol=tol)
...@@ -62,11 +62,6 @@ class Noise_Energy_Tests(unittest.TestCase): ...@@ -62,11 +62,6 @@ class Noise_Energy_Tests(unittest.TestCase):
alpha = ift.Field.full(d.domain, 2.) alpha = ift.Field.full(d.domain, 2.)
q = ift.Field.full(d.domain, 1e-5) q = ift.Field.full(d.domain, 1e-5)
direction = ift.Field.from_random('normal', d.domain)
direction /= np.sqrt(direction.var())
eps = 1e-8
eta1 = eta0 + eps * direction
IC = ift.GradientNormController( IC = ift.GradientNormController(
iteration_limit=100, iteration_limit=100,
tol_abs_gradnorm=1e-5) tol_abs_gradnorm=1e-5)
...@@ -88,4 +83,4 @@ class Noise_Energy_Tests(unittest.TestCase): ...@@ -88,4 +83,4 @@ class Noise_Energy_Tests(unittest.TestCase):
for _ in range(10)] for _ in range(10)]
energy = ift.library.NoiseEnergy(eta0, alpha, q, res_sample_list) energy = ift.library.NoiseEnergy(eta0, alpha, q, res_sample_list)
ift.extra.check_value_gradient_consistency(energy, tol=1e-8, ntries=10) ift.extra.check_value_gradient_consistency(energy, tol=1e-6, ntries=10)
...@@ -54,11 +54,6 @@ class Energy_Tests(unittest.TestCase): ...@@ -54,11 +54,6 @@ class Energy_Tests(unittest.TestCase):
N = ift.DiagonalOperator(diag) N = ift.DiagonalOperator(diag)
d = R(f(s)) + n d = R(f(s)) + n
direction = ift.Field.from_random('normal', pspace)
direction /= np.sqrt(direction.var())
eps = 1e-7
tau1 = tau0 + eps * direction
IC = ift.GradientNormController( IC = ift.GradientNormController(
iteration_limit=100, iteration_limit=100,
tol_abs_gradnorm=1e-5) tol_abs_gradnorm=1e-5)
...@@ -87,4 +82,4 @@ class Energy_Tests(unittest.TestCase): ...@@ -87,4 +82,4 @@ class Energy_Tests(unittest.TestCase):
ht=ht, ht=ht,
N=N, N=N,
samples=10) samples=10)
ift.extra.check_value_gradient_consistency(energy, tol=1e-8, ntries=10) ift.extra.check_value_gradient_consistency(energy, tol=1e-5, ntries=10)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment