Commit 43b247bb authored by Philipp Arras's avatar Philipp Arras
Browse files

LinearPowerEnergy test now passing

parent 72f1b059
Pipeline #24142 failed with stage
in 3 minutes and 58 seconds
......@@ -35,7 +35,7 @@ class CriticalPowerEnergy(Energy):
Parameters
----------
position : Field,
The current position of this energy.
The current position of this energy. (Logarithm of power spectrum)
m : Field,
The map whose power spectrum has to be inferred
D : EndomorphicOperator,
......@@ -101,8 +101,8 @@ class CriticalPowerEnergy(Energy):
self._theta = exp(-self.position) * (self.q + self._w*0.5)
Tt = self.T(self.position)
energy = self._theta.integrate()
energy += self.position.integrate()*(self.alpha-0.5)
energy = self._theta.vdot(Field.ones_like(self._theta))
energy += self.position.vdot(Field.ones_like(self.position)) *(self.alpha-0.5)
energy += 0.5*self.position.vdot(Tt)
self._value = energy.real
......
......@@ -39,21 +39,26 @@ class Power_Energy_Tests(unittest.TestCase):
P = ift.PowerProjectionOperator(domain=hspace, power_space=pspace)
xi = ift.Field.from_random(domain=hspace, random_type='normal')
def pspec(k): return 1 / (1 + k**2)**dim
tau0 = ift.PS_field(pspace, pspec)
A = P.adjoint_times(ift.sqrt(tau0))
n = ift.Field.from_random(domain=space, random_type='normal')
# TODO Power spectrum abhängig von Anzahl der Pixel
def pspec(k): return 64 / (1 + k**2)**dim
pspec = ift.PS_field(pspace, pspec)
tau0 = ift.log(pspec)
A = P.adjoint_times(ift.sqrt(pspec))
n = ift.Field.from_random(domain=space, random_type='normal', std=.01)
N = ift.DiagonalOperator(n**2)
s = xi * A
diag = ift.Field.ones(space) * 10
diag = ift.Field.ones(space)
Instrument = ift.DiagonalOperator(diag)
R = Instrument * ht
diag = ift.Field.ones(space)
N = ift.DiagonalOperator(diag)
d = R(s) + n
ift.plot(d, name='d.png')
ift.plot(ht(s), name='s.png')
ift.plot(n, name='n.png')
ift.plot(pspec, name='pspec.png')
direction = ift.Field.from_random('normal', pspace)
direction /= np.sqrt(direction.var())
eps = 1e-10
eps = 1e-7
tau1 = tau0 + eps * direction
IC = ift.GradientNormController(
......@@ -68,14 +73,21 @@ class Power_Energy_Tests(unittest.TestCase):
D = ift.library.WienerFilterEnergy(position=s, d=d, R=R, N=N, S=S,
inverter=inverter).curvature
w = ift.Field.zeros_like(tau0)
Nsamples = 10
for i in range(Nsamples):
sample = D.generate_posterior_sample() + s
w += P(abs(sample)**2)
w /= Nsamples
energy0 = ift.library.CriticalPowerEnergy(
position=tau0, m=xi, D=D, inverter=inverter)
position=tau0, m=s, inverter=inverter, w=w)
energy1 = ift.library.CriticalPowerEnergy(
position=tau1, m=xi, D=D, inverter=inverter)
position=tau1, m=s, inverter=inverter, w=w)
a = (energy1.value - energy0.value) / eps
b = energy0.gradient.vdot(direction)
tol = 1e-10
tol = 1e-4
assert_allclose(a, b, rtol=tol, atol=tol)
@expand(product([ift.RGSpace(64, distances=.789),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment