Commit 4cf96788 authored by Philipp Arras's avatar Philipp Arras
Browse files

More tests, remove plotting from tests

parent 75f0625d
Pipeline #24173 failed with stage
in 3 minutes and 59 seconds
......@@ -28,7 +28,7 @@ from numpy.testing import assert_allclose
# TODO Set tolerances and eps to reasonable values
class Map_Energy_Tests(unittest.TestCase):
class Energy_Tests(unittest.TestCase):
@expand(product([ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]))
......@@ -168,6 +168,8 @@ class Map_Energy_Tests(unittest.TestCase):
tol = 1e-4
assert_allclose(a, b, rtol=tol, atol=tol)
class Curvature_Tests(unittest.TestCase):
@expand(product([ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]))
......
......@@ -27,7 +27,7 @@ from numpy.testing import assert_allclose
# TODO Add also other space types
class Power_Energy_Tests(unittest.TestCase):
class Energy_Tests(unittest.TestCase):
@expand(product([ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[132, 42, 3]))
......@@ -53,10 +53,6 @@ class Power_Energy_Tests(unittest.TestCase):
Instrument = ift.DiagonalOperator(diag)
R = Instrument * ht
d = R(s) + n
ift.plot(d, name='d.png')
ift.plot(ht(s), name='s.png')
ift.plot(n, name='n.png')
ift.plot(pspec, name='pspec.png')
direction = ift.Field.from_random('normal', pspace)
direction /= np.sqrt(direction.var())
......@@ -171,3 +167,138 @@ class Power_Energy_Tests(unittest.TestCase):
b = energy0.gradient.vdot(direction)
tol = 1e-4
assert_allclose(a, b, rtol=tol, atol=tol)
class Curvature_Tests(unittest.TestCase):
@expand(product([ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[132, 42, 3]))
def testLinearPowerCurvature(self, space, seed):
np.random.seed(seed)
dim = len(space.shape)
hspace = space.get_default_codomain()
ht = ift.HarmonicTransformOperator(hspace, space)
binbounds = ift.PowerSpace.useful_binbounds(hspace, logarithmic=True)
pspace = ift.PowerSpace(hspace, binbounds=binbounds)
P = ift.PowerProjectionOperator(domain=hspace, power_space=pspace)
xi = ift.Field.from_random(domain=hspace, random_type='normal')
# TODO Power spectrum abhängig von Anzahl der Pixel
def pspec(k): return 64 / (1 + k**2)**dim
pspec = ift.PS_field(pspace, pspec)
tau0 = ift.log(pspec)
A = P.adjoint_times(ift.sqrt(pspec))
n = ift.Field.from_random(domain=space, random_type='normal', std=.01)
N = ift.DiagonalOperator(n**2)
s = xi * A
diag = ift.Field.ones(space)
Instrument = ift.DiagonalOperator(diag)
R = Instrument * ht
d = R(s) + n
direction = ift.Field.from_random('normal', pspace)
direction /= np.sqrt(direction.var())
eps = 1e-7
tau1 = tau0 + eps * direction
IC = ift.GradientNormController(
name='IC',
verbose=False,
iteration_limit=100,
tol_abs_gradnorm=1e-5)
inverter = ift.ConjugateGradient(IC)
S = ift.create_power_operator(hspace, power_spectrum=lambda k: 1.)
D = ift.library.WienerFilterEnergy(position=s, d=d, R=R, N=N, S=S,
inverter=inverter).curvature
w = ift.Field.zeros_like(tau0)
Nsamples = 10
for i in range(Nsamples):
sample = D.generate_posterior_sample() + s
w += P(abs(sample)**2)
w /= Nsamples
energy0 = ift.library.CriticalPowerEnergy(
position=tau0, m=s, inverter=inverter, w=w)
gradient0 = energy0.gradient
gradient1 = energy0.at(tau1).gradient
a = (gradient1 - gradient0) / eps
b = energy0.curvature(direction)
tol = 1e-5
assert_allclose(a.val, b.val, rtol=tol, atol=tol)
@expand(product([ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[ift.library.Exponential, ift.library.Linear],
[132, 42, 3]))
def testNonlinearPowerCurvature(self, space, nonlinearity, seed):
np.random.seed(seed)
f = nonlinearity()
dim = len(space.shape)
hspace = space.get_default_codomain()
ht = ift.HarmonicTransformOperator(hspace, space)
binbounds = ift.PowerSpace.useful_binbounds(hspace, logarithmic=True)
pspace = ift.PowerSpace(hspace, binbounds=binbounds)
P = ift.PowerProjectionOperator(domain=hspace, power_space=pspace)
xi = ift.Field.from_random(domain=hspace, random_type='normal')
def pspec(k): return 1 / (1 + k**2)**dim
tau0 = ift.PS_field(pspace, pspec)
A = P.adjoint_times(ift.sqrt(tau0))
n = ift.Field.from_random(domain=space, random_type='normal')
s = ht(xi * A)
diag = ift.Field.ones(space) * 10
R = ift.DiagonalOperator(diag)
diag = ift.Field.ones(space)
N = ift.DiagonalOperator(diag)
d = R(f(s)) + n
direction = ift.Field.from_random('normal', pspace)
direction /= np.sqrt(direction.var())
eps = 1e-7
tau1 = tau0 + eps * direction
IC = ift.GradientNormController(
name='IC',
verbose=False,
iteration_limit=100,
tol_abs_gradnorm=1e-5)
inverter = ift.ConjugateGradient(IC)
S = ift.create_power_operator(hspace, power_spectrum=lambda k: 1.)
D = ift.library.NonlinearWienerFilterEnergy(
position=xi,
d=d,
Instrument=R,
nonlinearity=f,
power=A,
N=N,
S=S,
ht=ht,
inverter=inverter).curvature
Nsamples = 10
sample_list = [D.generate_posterior_sample() + xi for _ in range(Nsamples)]
energy0 = ift.library.NonlinearPowerEnergy(
position=tau0,
d=d,
m=xi,
D=D,
Instrument=R,
Projection=P,
nonlinearity=f,
ht=ht,
N=N,
sample_list=sample_list)
gradient0 = energy0.gradient
gradient1 = energy0.at(tau1).gradient
a = (gradient1 - gradient0) / eps
b = energy0.curvature(direction)
tol = 1e-1
assert_allclose(a.val, b.val, rtol=tol, atol=tol)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment