Commit 69075882 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add test for power energy

parent e07a9ff6
Pipeline #24098 failed with stage
in 3 minutes and 57 seconds
......@@ -23,7 +23,7 @@ from .response_operators import LinearizedSignalResponse
class NonlinearWienerFilterEnergy(Energy):
def __init__(self, position, d, Instrument, nonlinearity, FFT, power, N, S, sunit,
def __init__(self, position, d, Instrument, nonlinearity, FFT, power, N, S, sunit=1.,
inverter=None):
super(NonlinearWienerFilterEnergy, self).__init__(position=position)
self.d = d
......
......@@ -74,3 +74,48 @@ class Energy_Tests(unittest.TestCase):
b = energy0.gradient.vdot(direction)
tol = 1e-2
assert_allclose(a, b, rtol=tol, atol=tol)
@expand(product([ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[ift.library.Exponential, ift.library.Linear]))
def testNonlinearPower(self, space, nonlinearity):
f = nonlinearity()
dim = len(space.shape)
fft = ift.FFTOperator(space)
hspace = fft.target[0]
binbounds = ift.PowerSpace.useful_binbounds(hspace, logarithmic=False)
pspace = ift.PowerSpace(hspace, binbounds=binbounds)
P = ift.PowerProjectionOperator(domain=hspace, power_space=pspace)
xi = ift.Field.from_random(domain=hspace, random_type='normal')
def pspec(k): return 1 / (1 + k**2)**dim
tau0 = ift.PS_field(pspace, pspec)
A = P.adjoint_times(ift.sqrt(tau0))
n = ift.Field.from_random(domain=space, random_type='normal')
s = fft.inverse_times(xi * A)
diag = ift.Field.ones(space) * 10
R = ift.DiagonalOperator(diag)
diag = ift.Field.ones(space)
N = ift.DiagonalOperator(diag)
d = R(f(s)) + n
direction = ift.Field.from_random('normal', pspace)
direction /= np.sqrt(direction.var())
eps = 1e-10
tau1 = tau0 + eps * direction
IC = ift.GradientNormController(name='IC', verbose=False, iteration_limit=100, tol_abs_gradnorm=1e-5)
inverter = ift.ConjugateGradient(IC)
S = ift.create_power_operator(hspace, power_spectrum=lambda k: 1.)
D = ift.library.NonlinearWienerFilterEnergy(position=xi, d=d, Instrument=R, nonlinearity=f, FFT=fft, power=A, N=N, S=S, inverter=inverter).curvature
energy0 = ift.library.NonlinearPowerEnergy(
position=tau0, d=d, m=xi, D=D, Instrument=R, Projection=P, nonlinearity=f, FFT=fft, N=N, inverter=inverter)
energy1 = ift.library.NonlinearPowerEnergy(
position=tau1, d=d, m=xi, D=D, Instrument=R, Projection=P, nonlinearity=f, FFT=fft, N=N, inverter=inverter)
a = (energy1.value - energy0.value) / eps
b = energy0.gradient.vdot(direction)
tol = 1e-2
assert_allclose(a, b, rtol=tol, atol=tol)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment