Skip to content
Snippets Groups Projects
Commit a84f6bff authored by Philipp Arras's avatar Philipp Arras
Browse files

Clean up more examples

parent b8afe00c
No related branches found
No related tags found
No related merge requests found
...@@ -20,9 +20,9 @@ import numpy as np ...@@ -20,9 +20,9 @@ import numpy as np
import helpers as h import helpers as h
import nifty5 as ift import nifty5 as ift
seeds = [123, 42, 42] seeds = [123, 42]
name = ['bernoulli', 'gauss', 'poisson'] name = ['bernoulli', 'poisson']
for mode in [0, 1, 2]: for mode in [0, 1]:
np.random.seed(seeds[mode]) np.random.seed(seeds[mode])
position_space = ift.RGSpace([256, 256]) position_space = ift.RGSpace([256, 256])
...@@ -50,12 +50,7 @@ for mode in [0, 1, 2]: ...@@ -50,12 +50,7 @@ for mode in [0, 1, 2]:
if mode == 0: if mode == 0:
signal = correlated_field.sigmoid() signal = correlated_field.sigmoid()
R = h.checkerboard_response(position_space) R = h.checkerboard_response(position_space)
elif mode == 1: else:
signal = correlated_field.exp()
R = h.radial_tomography_response(position_space, lines_of_sight=256)
N = ift.ScalingOperator(5., R.target)
dct['N'] = N
elif mode == 2:
signal = correlated_field.exp() signal = correlated_field.exp()
R = h.exposure_response(position_space) R = h.exposure_response(position_space)
h.plot_prior_samples_2d(5, signal, R, correlated_field, A, name[mode], h.plot_prior_samples_2d(5, signal, R, correlated_field, A, name[mode],
...@@ -65,10 +60,7 @@ for mode in [0, 1, 2]: ...@@ -65,10 +60,7 @@ for mode in [0, 1, 2]:
signal_response = signal_response.clip(1e-5, 1 - 1e-5) signal_response = signal_response.clip(1e-5, 1 - 1e-5)
data, ground_truth = h.generate_bernoulli_data(signal_response) data, ground_truth = h.generate_bernoulli_data(signal_response)
likelihood = ift.BernoulliEnergy(data) @ signal_response likelihood = ift.BernoulliEnergy(data) @ signal_response
elif mode == 1: else:
data, ground_truth = h.generate_gaussian_data(signal_response, N)
likelihood = ift.GaussianEnergy(data, N) @ signal_response
elif mode == 2:
data, ground_truth = h.generate_poisson_data(signal_response) data, ground_truth = h.generate_poisson_data(signal_response)
likelihood = ift.PoissonianEnergy(data) @ signal_response likelihood = ift.PoissonianEnergy(data) @ signal_response
...@@ -80,7 +72,7 @@ for mode in [0, 1, 2]: ...@@ -80,7 +72,7 @@ for mode in [0, 1, 2]:
H = ift.StandardHamiltonian(likelihood, ic_sampling) H = ift.StandardHamiltonian(likelihood, ic_sampling)
initial_mean = ift.MultiField.full(H.domain, 0.) initial_mean = ift.MultiField.full(H.domain, 0.)
mean = initial_mean mean = initial_mean
N_samples = 5 if mode in [0, 2] else 10 N_samples = 5
for _ in range(5): for _ in range(5):
# Draw new samples and minimize KL # Draw new samples and minimize KL
KL = ift.MetricGaussianKL(mean, H, N_samples) KL = ift.MetricGaussianKL(mean, H, N_samples)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment