Commit fb3d4658 authored by Philipp Arras's avatar Philipp Arras

Use sample generation of energies

parent b6d9b00a
......@@ -82,15 +82,11 @@ class Noise_Energy_Tests(unittest.TestCase):
N=N,
S=S,
inverter=inverter).curvature
Nsamples = 10
xi_sample_list = [
D.generate_posterior_sample() +
xi for i in range(Nsamples)]
energy0 = ift.library.NoiseEnergy(
position=eta0, d=d, xi=xi, D=D, t=tau, Instrument=R,
alpha=alpha, q=q, Projection=P, nonlinearity=f,
ht=ht, xi_sample_list=xi_sample_list)
ht=ht, samples=10)
energy1 = energy0.at(eta1)
a = (energy1.value - energy0.value) / eps
......
......@@ -69,14 +69,9 @@ class Energy_Tests(unittest.TestCase):
inverter=inverter).curvature
w = ift.Field.zeros_like(tau0)
Nsamples = 10
for i in range(Nsamples):
sample = D.generate_posterior_sample() + s
w += P(abs(sample)**2)
w /= Nsamples
energy0 = ift.library.CriticalPowerEnergy(
position=tau0, m=s, inverter=inverter, w=w)
position=tau0, m=s, inverter=inverter, w=w, samples=10)
energy1 = energy0.at(tau1)
a = (energy1.value - energy0.value) / eps
......@@ -130,10 +125,6 @@ class Energy_Tests(unittest.TestCase):
S=S,
ht=ht,
inverter=inverter).curvature
Nsamples = 10
xi_sample_list = [
D.generate_posterior_sample() +
xi for _ in range(Nsamples)]
energy0 = ift.library.NonlinearPowerEnergy(
position=tau0,
......@@ -145,7 +136,7 @@ class Energy_Tests(unittest.TestCase):
nonlinearity=f,
ht=ht,
N=N,
xi_sample_list=xi_sample_list)
samples=10)
energy1 = energy0.at(tau1)
a = (energy1.value - energy0.value) / eps
......@@ -196,15 +187,8 @@ class Curvature_Tests(unittest.TestCase):
D = ift.library.WienerFilterEnergy(position=s, d=d, R=R, N=N, S=S,
inverter=inverter).curvature
w = ift.Field.zeros_like(tau0)
Nsamples = 10
for i in range(Nsamples):
sample = D.generate_posterior_sample() + s
w += P(abs(sample)**2)
w /= Nsamples
energy0 = ift.library.CriticalPowerEnergy(
position=tau0, m=s, inverter=inverter, w=w)
position=tau0, m=s, inverter=inverter, samples=10)
gradient0 = energy0.gradient
gradient1 = energy0.at(tau1).gradient
......@@ -261,10 +245,6 @@ class Curvature_Tests(unittest.TestCase):
S=S,
ht=ht,
inverter=inverter).curvature
Nsamples = 10
xi_sample_list = [
D.generate_posterior_sample() +
xi for _ in range(Nsamples)]
energy0 = ift.library.NonlinearPowerEnergy(
position=tau0,
......@@ -276,7 +256,7 @@ class Curvature_Tests(unittest.TestCase):
nonlinearity=f,
ht=ht,
N=N,
xi_sample_list=xi_sample_list)
samples=10)
gradient0 = energy0.gradient
gradient1 = energy0.at(tau1).gradient
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment