diff --git a/2_critical_filter_solution.py b/2_critical_filter_solution.py index 0fe20015649a8f58621ed34d0e852b75106eae42..8e78e3eae010a1aeea331995e27a53e94d877a22 100644 --- a/2_critical_filter_solution.py +++ b/2_critical_filter_solution.py @@ -63,14 +63,14 @@ H = ift.StandardHamiltonian(likelihood, ic_sampling) initial_mean = ift.MultiField.full(H.domain, 0.) mean = initial_mean -# Draw five samples and minimize KL, iterate 10 times -for _ in range(3): +# Draw five samples and minimize KL, iterate 5 times +n_rounds = 5 +for ii in range(n_rounds): KL = ift.MetricGaussianKL(mean, H, 5, True) + KL = ift.MetricGaussianKL(mean, H, 30 if ii == n_rounds-1 else 5, True) KL, convergence = minimizer(KL) mean = KL.position -# Draw posterior samples and plot -N_posterior_samples = 30 -KL = ift.MetricGaussianKL(mean, H, N_posterior_samples, True) +# Plot posterior samples plot_reconstruction_2d(data, ground_truth, KL, signal, R, signal.amplitude, '2_criticalfilter') diff --git a/3_more_examples.py b/3_more_examples.py index b71980ffcb8b0540d398d7a5686126f5e8a70d1c..a82507b71a9175dce8aed7a807548c5b0c52a56f 100644 --- a/3_more_examples.py +++ b/3_more_examples.py @@ -73,12 +73,10 @@ for mode in [0, 1]: H = ift.StandardHamiltonian(likelihood, ic_sampling) initial_mean = ift.MultiField.full(H.domain, 0.) mean = initial_mean - N_samples = 5 - for _ in range(5): + n_rounds = 5 + for ii in range(n_rounds): # Draw new samples and minimize KL - KL = ift.MetricGaussianKL(mean, H, N_samples, True) + KL = ift.MetricGaussianKL(mean, H, 30 if ii == n_rounds-1 else 5, True) KL, convergence = minimizer(KL) mean = KL.position - N_posterior_samples = 30 - KL = ift.MetricGaussianKL(mean, H, N_posterior_samples, True) h.plot_reconstruction_2d(data, ground_truth, KL, signal, R, A, name[mode]) diff --git a/teaser_critical_filter.py b/teaser_critical_filter.py index a0cee0ee7f0add374974aa981f484addf170ad99..468caebf8eb9808adf365679b81ab0af711ac6b9 100644 --- a/teaser_critical_filter.py +++ b/teaser_critical_filter.py @@ -78,18 +78,13 @@ mean = initial_mean N_samples = 10 # Draw new samples to approximate the KL ten times -for i in range(10): +for _ in range(10): # Draw new samples and minimize KL KL = ift.MetricGaussianKL(mean, H, N_samples, True) KL, convergence = minimizer(KL) mean = KL.position - # FIXME Minimize posterior samples (global) -# Draw posterior samples and plotting -N_posterior_samples = 10 -KL = ift.MetricGaussianKL(mean, H, N_posterior_samples, True) - -# Plot the reconstruction result +# Plot posterior samples ground_truth = ift.makeField(position_space, ground_truth) posterior_samples = [correlated_field(KL.position+samp) for samp in KL.samples] mean = 0.*posterior_samples[0]