Skip to content
Snippets Groups Projects
Commit 1c78f918 authored by Philipp Arras's avatar Philipp Arras
Browse files

Energy tests ready to merge

Sadly, it is not possible to test the curvatures other than in the linear case.
We have failed. The reason is that the curvature is approximated in a way that
it is positive definite. For details see Knollmüller and Ensslin 2017.
parent d3a84b2d
Branches
Tags
1 merge request!215Energy tests
Pipeline #
...@@ -144,11 +144,11 @@ class Energy_Tests(unittest.TestCase): ...@@ -144,11 +144,11 @@ class Energy_Tests(unittest.TestCase):
S = ift.create_power_operator(hspace, power_spectrum=lambda k: 1.) S = ift.create_power_operator(hspace, power_spectrum=lambda k: 1.)
energy0 = ift.library.NonlinearWienerFilterEnergy( energy0 = ift.library.NonlinearWienerFilterEnergy(
position=xi0, d=d, Instrument=R, nonlinearity=f, ht=ht, power=A, N=N, S=S) position=xi0, d=d, Instrument=R, nonlinearity=f, ht=ht, power=A, N=N, S=S)
energy1 = energy0.at(xi0) energy1 = energy0.at(xi1)
a = (energy1.value - energy0.value) / eps a = (energy1.value - energy0.value) / eps
b = energy0.gradient.vdot(direction) b = energy0.gradient.vdot(direction)
tol = 1e-4 tol = 1e-5
assert_allclose(a, b, rtol=tol, atol=tol) assert_allclose(a, b, rtol=tol, atol=tol)
...@@ -197,55 +197,10 @@ class Curvature_Tests(unittest.TestCase): ...@@ -197,55 +197,10 @@ class Curvature_Tests(unittest.TestCase):
tol = 1e-7 tol = 1e-7
assert_allclose(a.val, b.val, rtol=tol, atol=tol) assert_allclose(a.val, b.val, rtol=tol, atol=tol)
@expand(product([ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]))
def testLognormalMapCurvature(self, space, seed):
np.random.seed(seed)
dim = len(space.shape)
hspace = space.get_default_codomain()
ht = ift.HarmonicTransformOperator(hspace, target=space)
binbounds = ift.PowerSpace.useful_binbounds(hspace, logarithmic=False)
pspace = ift.PowerSpace(hspace, binbounds=binbounds)
P = ift.PowerProjectionOperator(domain=hspace, power_space=pspace)
xi0 = ift.Field.from_random(domain=hspace, random_type='normal')
def pspec(k): return 1 / (1 + k**2)**dim
pspec = ift.PS_field(pspace, pspec)
A = P.adjoint_times(ift.sqrt(pspec))
n = ift.Field.from_random(domain=space, random_type='normal')
sh0 = xi0 * A
s = ht(sh0)
Instrument = ift.ScalingOperator(10., space)
R = Instrument * ht
N = ift.ScalingOperator(1., space)
d = Instrument(ift.exp(s)) + n
direction = ift.Field.from_random('normal', hspace)
direction /= np.sqrt(direction.var())
eps = 1e-7
sh1 = sh0 + eps * direction
IC = ift.GradientNormController(
iteration_limit=100,
tol_abs_gradnorm=1e-5)
inverter = ift.ConjugateGradient(IC)
S = ift.create_power_operator(hspace, power_spectrum=lambda k: 1.)
energy0 = ift.library.LogNormalWienerFilterEnergy(
position=sh0, d=d, R=R, N=N, S=S, inverter=inverter)
gradient0 = energy0.gradient
gradient1 = energy0.at(sh1).gradient
a = (gradient1 - gradient0) / eps
b = energy0.curvature(direction)
tol = 1e-3
assert_allclose(a.val, b.val, rtol=tol, atol=tol)
@expand(product([ift.RGSpace(64, distances=.789), @expand(product([ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)], ift.RGSpace([32, 32], distances=.789)],
[ift.library.Exponential, ift.library.Linear], [ift.library.Linear], # Only linear case due to approximation of Hessian in the case of nontrivial nonlinearities.
[4, 78, 23])) [4, 78, 23]))
def testNonlinearMapCurvature(self, space, nonlinearity, seed): def testNonlinearMapCurvature(self, space, nonlinearity, seed):
np.random.seed(seed) np.random.seed(seed)
...@@ -293,5 +248,7 @@ class Curvature_Tests(unittest.TestCase): ...@@ -293,5 +248,7 @@ class Curvature_Tests(unittest.TestCase):
a = (gradient1 - gradient0) / eps a = (gradient1 - gradient0) / eps
b = energy0.curvature(direction) b = energy0.curvature(direction)
tol = 1e-3 print(a.vdot(a))
print(b.vdot(b))
tol = 1e-7
assert_allclose(a.val, b.val, rtol=tol, atol=tol) assert_allclose(a.val, b.val, rtol=tol, atol=tol)
...@@ -59,15 +59,13 @@ class Energy_Tests(unittest.TestCase): ...@@ -59,15 +59,13 @@ class Energy_Tests(unittest.TestCase):
tol_abs_gradnorm=1e-5) tol_abs_gradnorm=1e-5)
inverter = ift.ConjugateGradient(IC) inverter = ift.ConjugateGradient(IC)
S = ift.create_power_operator(hspace, power_spectrum=lambda k: 1.) S = ift.create_power_operator(hspace, power_spectrum=lambda k: 1./(1+k**2))
D = ift.library.WienerFilterEnergy(position=s, d=d, R=R, N=N, S=S, D = ift.library.WienerFilterEnergy(position=s, d=d, R=R, N=N, S=S,
inverter=inverter).curvature inverter=inverter).curvature
w = ift.Field.zeros_like(tau0)
energy0 = ift.library.CriticalPowerEnergy( energy0 = ift.library.CriticalPowerEnergy(
position=tau0, m=s, inverter=inverter, w=w, samples=10) position=tau0, m=s, inverter=inverter, D=D, samples=10, smoothness_prior=1.)
energy1 = energy0.at(tau1) energy1 = energy0.at(tau1)
a = (energy1.value - energy0.value) / eps a = (energy1.value - energy0.value) / eps
...@@ -192,71 +190,3 @@ class Curvature_Tests(unittest.TestCase): ...@@ -192,71 +190,3 @@ class Curvature_Tests(unittest.TestCase):
b = energy0.curvature(direction) b = energy0.curvature(direction)
tol = 1e-5 tol = 1e-5
assert_allclose(a.val, b.val, rtol=tol, atol=tol) assert_allclose(a.val, b.val, rtol=tol, atol=tol)
@expand(product([ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[ift.library.Exponential, ift.library.Linear],
[132, 42, 3]))
def testNonlinearPowerCurvature(self, space, nonlinearity, seed):
np.random.seed(seed)
f = nonlinearity()
dim = len(space.shape)
hspace = space.get_default_codomain()
ht = ift.HarmonicTransformOperator(hspace, space)
binbounds = ift.PowerSpace.useful_binbounds(hspace, logarithmic=True)
pspace = ift.PowerSpace(hspace, binbounds=binbounds)
P = ift.PowerProjectionOperator(domain=hspace, power_space=pspace)
xi = ift.Field.from_random(domain=hspace, random_type='normal')
def pspec(k): return 1 / (1 + k**2)**dim
tau0 = ift.PS_field(pspace, pspec)
A = P.adjoint_times(ift.sqrt(tau0))
n = ift.Field.from_random(domain=space, random_type='normal')
s = ht(xi * A)
diag = ift.Field.ones(space) * 10
R = ift.DiagonalOperator(diag)
diag = ift.Field.ones(space)
N = ift.DiagonalOperator(diag)
d = R(f(s)) + n
direction = ift.Field.from_random('normal', pspace)
direction /= np.sqrt(direction.var())
eps = 1e-7
tau1 = tau0 + eps * direction
IC = ift.GradientNormController(
iteration_limit=100,
tol_abs_gradnorm=1e-5)
inverter = ift.ConjugateGradient(IC)
S = ift.create_power_operator(hspace, power_spectrum=lambda k: 1.)
D = ift.library.NonlinearWienerFilterEnergy(
position=xi,
d=d,
Instrument=R,
nonlinearity=f,
power=A,
N=N,
S=S,
ht=ht,
inverter=inverter).curvature
energy0 = ift.library.NonlinearPowerEnergy(
position=tau0,
d=d,
xi=xi,
D=D,
Instrument=R,
Projection=P,
nonlinearity=f,
ht=ht,
N=N,
samples=10)
gradient0 = energy0.gradient
gradient1 = energy0.at(tau1).gradient
a = (gradient1 - gradient0) / eps
b = energy0.curvature(direction)
tol = 1e-3
assert_allclose(a.val, b.val, rtol=tol, atol=tol)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment