From 9a565cd5d60770ca8d10c73b574ad637d6352458 Mon Sep 17 00:00:00 2001 From: Philipp Arras <parras@mpa-garching.mpg.de> Date: Sun, 15 Aug 2021 17:28:42 +0200 Subject: [PATCH] 7/n --- 2_critical_filter_solution.py | 10 +++++----- 3_more_examples.py | 8 +++----- teaser_critical_filter.py | 9 ++------- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/2_critical_filter_solution.py b/2_critical_filter_solution.py index 0fe2001..8e78e3e 100644 --- a/2_critical_filter_solution.py +++ b/2_critical_filter_solution.py @@ -63,14 +63,14 @@ H = ift.StandardHamiltonian(likelihood, ic_sampling) initial_mean = ift.MultiField.full(H.domain, 0.) mean = initial_mean -# Draw five samples and minimize KL, iterate 10 times -for _ in range(3): +# Draw five samples and minimize KL, iterate 5 times +n_rounds = 5 +for ii in range(n_rounds): KL = ift.MetricGaussianKL(mean, H, 5, True) + KL = ift.MetricGaussianKL(mean, H, 30 if ii == n_rounds-1 else 5, True) KL, convergence = minimizer(KL) mean = KL.position -# Draw posterior samples and plot -N_posterior_samples = 30 -KL = ift.MetricGaussianKL(mean, H, N_posterior_samples, True) +# Plot posterior samples plot_reconstruction_2d(data, ground_truth, KL, signal, R, signal.amplitude, '2_criticalfilter') diff --git a/3_more_examples.py b/3_more_examples.py index b71980f..a82507b 100644 --- a/3_more_examples.py +++ b/3_more_examples.py @@ -73,12 +73,10 @@ for mode in [0, 1]: H = ift.StandardHamiltonian(likelihood, ic_sampling) initial_mean = ift.MultiField.full(H.domain, 0.) mean = initial_mean - N_samples = 5 - for _ in range(5): + n_rounds = 5 + for ii in range(n_rounds): # Draw new samples and minimize KL - KL = ift.MetricGaussianKL(mean, H, N_samples, True) + KL = ift.MetricGaussianKL(mean, H, 30 if ii == n_rounds-1 else 5, True) KL, convergence = minimizer(KL) mean = KL.position - N_posterior_samples = 30 - KL = ift.MetricGaussianKL(mean, H, N_posterior_samples, True) h.plot_reconstruction_2d(data, ground_truth, KL, signal, R, A, name[mode]) diff --git a/teaser_critical_filter.py b/teaser_critical_filter.py index a0cee0e..468caeb 100644 --- a/teaser_critical_filter.py +++ b/teaser_critical_filter.py @@ -78,18 +78,13 @@ mean = initial_mean N_samples = 10 # Draw new samples to approximate the KL ten times -for i in range(10): +for _ in range(10): # Draw new samples and minimize KL KL = ift.MetricGaussianKL(mean, H, N_samples, True) KL, convergence = minimizer(KL) mean = KL.position - # FIXME Minimize posterior samples (global) -# Draw posterior samples and plotting -N_posterior_samples = 10 -KL = ift.MetricGaussianKL(mean, H, N_posterior_samples, True) - -# Plot the reconstruction result +# Plot posterior samples ground_truth = ift.makeField(position_space, ground_truth) posterior_samples = [correlated_field(KL.position+samp) for samp in KL.samples] mean = 0.*posterior_samples[0] -- GitLab