From 9a565cd5d60770ca8d10c73b574ad637d6352458 Mon Sep 17 00:00:00 2001
From: Philipp Arras <parras@mpa-garching.mpg.de>
Date: Sun, 15 Aug 2021 17:28:42 +0200
Subject: [PATCH] 7/n

---
 2_critical_filter_solution.py | 10 +++++-----
 3_more_examples.py            |  8 +++-----
 teaser_critical_filter.py     |  9 ++-------
 3 files changed, 10 insertions(+), 17 deletions(-)

diff --git a/2_critical_filter_solution.py b/2_critical_filter_solution.py
index 0fe2001..8e78e3e 100644
--- a/2_critical_filter_solution.py
+++ b/2_critical_filter_solution.py
@@ -63,14 +63,14 @@ H = ift.StandardHamiltonian(likelihood, ic_sampling)
 initial_mean = ift.MultiField.full(H.domain, 0.)
 mean = initial_mean
 
-# Draw five samples and minimize KL, iterate 10 times
-for _ in range(3):
+# Draw five samples and minimize KL, iterate 5 times
+n_rounds = 5
+for ii in range(n_rounds):
     KL = ift.MetricGaussianKL(mean, H, 5, True)
+    KL = ift.MetricGaussianKL(mean, H, 30 if ii == n_rounds-1 else 5, True)
     KL, convergence = minimizer(KL)
     mean = KL.position
 
-# Draw posterior samples and plot
-N_posterior_samples = 30
-KL = ift.MetricGaussianKL(mean, H, N_posterior_samples, True)
+# Plot posterior samples
 plot_reconstruction_2d(data, ground_truth, KL, signal, R, signal.amplitude,
                        '2_criticalfilter')
diff --git a/3_more_examples.py b/3_more_examples.py
index b71980f..a82507b 100644
--- a/3_more_examples.py
+++ b/3_more_examples.py
@@ -73,12 +73,10 @@ for mode in [0, 1]:
     H = ift.StandardHamiltonian(likelihood, ic_sampling)
     initial_mean = ift.MultiField.full(H.domain, 0.)
     mean = initial_mean
-    N_samples = 5
-    for _ in range(5):
+    n_rounds = 5
+    for ii in range(n_rounds):
         # Draw new samples and minimize KL
-        KL = ift.MetricGaussianKL(mean, H, N_samples, True)
+        KL = ift.MetricGaussianKL(mean, H, 30 if ii == n_rounds-1 else 5, True)
         KL, convergence = minimizer(KL)
         mean = KL.position
-    N_posterior_samples = 30
-    KL = ift.MetricGaussianKL(mean, H, N_posterior_samples, True)
     h.plot_reconstruction_2d(data, ground_truth, KL, signal, R, A, name[mode])
diff --git a/teaser_critical_filter.py b/teaser_critical_filter.py
index a0cee0e..468caeb 100644
--- a/teaser_critical_filter.py
+++ b/teaser_critical_filter.py
@@ -78,18 +78,13 @@ mean = initial_mean
 N_samples = 10
 
 # Draw new samples to approximate the KL ten times
-for i in range(10):
+for _ in range(10):
     # Draw new samples and minimize KL
     KL = ift.MetricGaussianKL(mean, H, N_samples, True)
     KL, convergence = minimizer(KL)
     mean = KL.position
-    # FIXME Minimize posterior samples (global)
 
-# Draw posterior samples and plotting
-N_posterior_samples = 10
-KL = ift.MetricGaussianKL(mean, H, N_posterior_samples, True)
-
-# Plot the reconstruction result
+# Plot posterior samples
 ground_truth = ift.makeField(position_space, ground_truth)
 posterior_samples = [correlated_field(KL.position+samp) for samp in KL.samples]
 mean = 0.*posterior_samples[0]
-- 
GitLab