Skip to content
Snippets Groups Projects
Commit 9a565cd5 authored by Philipp Arras's avatar Philipp Arras
Browse files

7/n

parent 880f85dd
Branches
No related tags found
1 merge request!2Draft: Nifty5 to nifty7
...@@ -63,14 +63,14 @@ H = ift.StandardHamiltonian(likelihood, ic_sampling) ...@@ -63,14 +63,14 @@ H = ift.StandardHamiltonian(likelihood, ic_sampling)
initial_mean = ift.MultiField.full(H.domain, 0.) initial_mean = ift.MultiField.full(H.domain, 0.)
mean = initial_mean mean = initial_mean
# Draw five samples and minimize KL, iterate 10 times # Draw five samples and minimize KL, iterate 5 times
for _ in range(3): n_rounds = 5
for ii in range(n_rounds):
KL = ift.MetricGaussianKL(mean, H, 5, True) KL = ift.MetricGaussianKL(mean, H, 5, True)
KL = ift.MetricGaussianKL(mean, H, 30 if ii == n_rounds-1 else 5, True)
KL, convergence = minimizer(KL) KL, convergence = minimizer(KL)
mean = KL.position mean = KL.position
# Draw posterior samples and plot # Plot posterior samples
N_posterior_samples = 30
KL = ift.MetricGaussianKL(mean, H, N_posterior_samples, True)
plot_reconstruction_2d(data, ground_truth, KL, signal, R, signal.amplitude, plot_reconstruction_2d(data, ground_truth, KL, signal, R, signal.amplitude,
'2_criticalfilter') '2_criticalfilter')
...@@ -73,12 +73,10 @@ for mode in [0, 1]: ...@@ -73,12 +73,10 @@ for mode in [0, 1]:
H = ift.StandardHamiltonian(likelihood, ic_sampling) H = ift.StandardHamiltonian(likelihood, ic_sampling)
initial_mean = ift.MultiField.full(H.domain, 0.) initial_mean = ift.MultiField.full(H.domain, 0.)
mean = initial_mean mean = initial_mean
N_samples = 5 n_rounds = 5
for _ in range(5): for ii in range(n_rounds):
# Draw new samples and minimize KL # Draw new samples and minimize KL
KL = ift.MetricGaussianKL(mean, H, N_samples, True) KL = ift.MetricGaussianKL(mean, H, 30 if ii == n_rounds-1 else 5, True)
KL, convergence = minimizer(KL) KL, convergence = minimizer(KL)
mean = KL.position mean = KL.position
N_posterior_samples = 30
KL = ift.MetricGaussianKL(mean, H, N_posterior_samples, True)
h.plot_reconstruction_2d(data, ground_truth, KL, signal, R, A, name[mode]) h.plot_reconstruction_2d(data, ground_truth, KL, signal, R, A, name[mode])
...@@ -78,18 +78,13 @@ mean = initial_mean ...@@ -78,18 +78,13 @@ mean = initial_mean
N_samples = 10 N_samples = 10
# Draw new samples to approximate the KL ten times # Draw new samples to approximate the KL ten times
for i in range(10): for _ in range(10):
# Draw new samples and minimize KL # Draw new samples and minimize KL
KL = ift.MetricGaussianKL(mean, H, N_samples, True) KL = ift.MetricGaussianKL(mean, H, N_samples, True)
KL, convergence = minimizer(KL) KL, convergence = minimizer(KL)
mean = KL.position mean = KL.position
# FIXME Minimize posterior samples (global)
# Draw posterior samples and plotting # Plot posterior samples
N_posterior_samples = 10
KL = ift.MetricGaussianKL(mean, H, N_posterior_samples, True)
# Plot the reconstruction result
ground_truth = ift.makeField(position_space, ground_truth) ground_truth = ift.makeField(position_space, ground_truth)
posterior_samples = [correlated_field(KL.position+samp) for samp in KL.samples] posterior_samples = [correlated_field(KL.position+samp) for samp in KL.samples]
mean = 0.*posterior_samples[0] mean = 0.*posterior_samples[0]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment