diff --git a/1_wiener_filter.py b/1_wiener_filter.py index bb194af450d4c31365011f2b97bcda8e71f2cf9b..81ea13ebc31facecebbe3af04d37e949610d3acc 100644 --- a/1_wiener_filter.py +++ b/1_wiener_filter.py @@ -1,20 +1,3 @@ -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see <http://www.gnu.org/licenses/>. -# -# Copyright(C) 2013-2019 Max-Planck-Society -# -# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. - import numpy as np import nifty5 as ift @@ -23,40 +6,3 @@ from helpers import generate_wf_data, plot_WF np.random.seed(42) # Want to implement: m = Dj = (S^{-1} + R^T N^{-1} R)^{-1} R^T N^{-1} d - -position_space = ift.RGSpace(256) - -prior_spectrum = lambda k: 1/(10. + k**2.5) -data, ground_truth = generate_wf_data(position_space, prior_spectrum) - -R = ift.GeometryRemover(position_space) -data_space = R.target -data = ift.from_global_data(data_space, data) - -ground_truth = ift.from_global_data(position_space, ground_truth) -plot_WF('data', ground_truth, data) - -N = ift.ScalingOperator(0.1, data_space) - -harmonic_space = position_space.get_default_codomain() -HT = ift.HartleyOperator(harmonic_space, target=position_space) - -S_h = ift.create_power_operator(harmonic_space, prior_spectrum) -S = HT @ S_h @ HT.adjoint - -D_inv = S.inverse + R.adjoint @ N.inverse @ R -j = (R.adjoint @ N.inverse)(data) - -IC = ift.GradientNormController(iteration_limit=100, tol_abs_gradnorm=1e-7) -D = ift.InversionEnabler(D_inv.inverse, IC, approximation=S) - -m = D(j) - -plot_WF('result', ground_truth, data, m) - -S = ift.SandwichOperator.make(HT.adjoint, S_h) -D = ift.WienerFilterCurvature(R, N, S, IC, IC).inverse -N_samples = 10 -samples = [D.draw_sample() + m for i in range(N_samples)] - -plot_WF('result', ground_truth, data, m=m, samples=samples) diff --git a/1_wiener_filter_solution.py b/1_wiener_filter_solution.py new file mode 100644 index 0000000000000000000000000000000000000000..bb194af450d4c31365011f2b97bcda8e71f2cf9b --- /dev/null +++ b/1_wiener_filter_solution.py @@ -0,0 +1,62 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2019 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. + +import numpy as np + +import nifty5 as ift +from helpers import generate_wf_data, plot_WF + +np.random.seed(42) + +# Want to implement: m = Dj = (S^{-1} + R^T N^{-1} R)^{-1} R^T N^{-1} d + +position_space = ift.RGSpace(256) + +prior_spectrum = lambda k: 1/(10. + k**2.5) +data, ground_truth = generate_wf_data(position_space, prior_spectrum) + +R = ift.GeometryRemover(position_space) +data_space = R.target +data = ift.from_global_data(data_space, data) + +ground_truth = ift.from_global_data(position_space, ground_truth) +plot_WF('data', ground_truth, data) + +N = ift.ScalingOperator(0.1, data_space) + +harmonic_space = position_space.get_default_codomain() +HT = ift.HartleyOperator(harmonic_space, target=position_space) + +S_h = ift.create_power_operator(harmonic_space, prior_spectrum) +S = HT @ S_h @ HT.adjoint + +D_inv = S.inverse + R.adjoint @ N.inverse @ R +j = (R.adjoint @ N.inverse)(data) + +IC = ift.GradientNormController(iteration_limit=100, tol_abs_gradnorm=1e-7) +D = ift.InversionEnabler(D_inv.inverse, IC, approximation=S) + +m = D(j) + +plot_WF('result', ground_truth, data, m) + +S = ift.SandwichOperator.make(HT.adjoint, S_h) +D = ift.WienerFilterCurvature(R, N, S, IC, IC).inverse +N_samples = 10 +samples = [D.draw_sample() + m for i in range(N_samples)] + +plot_WF('result', ground_truth, data, m=m, samples=samples) diff --git a/2_binary_gp_classification.py b/2_binary_gp_classification.py deleted file mode 100644 index 0d7ef9393f1360b5734ad6224f2be79d830fc2de..0000000000000000000000000000000000000000 --- a/2_binary_gp_classification.py +++ /dev/null @@ -1,89 +0,0 @@ -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see <http://www.gnu.org/licenses/>. -# -# Copyright(C) 2013-2019 Max-Planck-Society -# -# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. - -import numpy as np - -import nifty5 as ift -from helpers import (checkerboard_response, generate_bernoulli_data, - plot_prior_samples_2d, plot_reconstruction_2d) - -np.random.seed(123) - -position_space = ift.RGSpace([256, 256]) -harmonic_space = position_space.get_default_codomain() -power_space = ift.PowerSpace(harmonic_space) - -# Build model -HT = ift.HarmonicTransformOperator(harmonic_space, target=position_space) - -# Set up an amplitude operator for the field -# We want to set up a model for the amplitude spectrum with some magic numbers -dct = { - 'target': power_space, - 'n_pix': 64, # 64 spectral bins - # Spectral smoothness (affects Gaussian process part) - 'a': 10, # relatively high variance of spectral curvature - 'k0': .2, # quefrency mode below which cepstrum flattens - # Power-law part of spectrum: - 'sm': -4, # preferred power-law slope - 'sv': .6, # low variance of power-law slope - 'im': -3, # y-intercept mean, in-/decrease for more/less contrast - 'iv': 2. # y-intercept variance -} -A = ift.SLAmplitude(**dct) -correlated_field = ift.CorrelatedField(position_space, A) - -# Set up specific scenario -signal = correlated_field.sigmoid() - -R = checkerboard_response(position_space) -signal_response = R(signal).clip(1e-5, 1 - 1e-5) - -# Plot prior samples -plot_prior_samples_2d(5, signal, R, correlated_field, A, 'bernoulli') - -data_space = R.target -data, ground_truth = generate_bernoulli_data(signal_response) - -# Set up likelihood and information Hamiltonian -likelihood = ift.BernoulliEnergy(data)(signal_response) - -# Solve problem -# Minimization parameters -ic_sampling = ift.GradientNormController(iteration_limit=60) -ic_newton = ift.GradInfNormController( - name='Newton', tol=1e-6, iteration_limit=50) -minimizer = ift.NewtonCG(ic_newton) - -H = ift.StandardHamiltonian(likelihood, ic_sampling) -initial_mean = ift.MultiField.full(H.domain, 0.) -mean = initial_mean - -# Number of samples used to estimate the KL -N_samples = 5 - -# Draw new samples to approximate the KL five times -for _ in range(5): - # Draw new samples and minimize KL - KL = ift.MetricGaussianKL(mean, H, N_samples) - KL, convergence = minimizer(KL) - mean = KL.position - -# Plot results -N_posterior_samples = 30 -KL = ift.MetricGaussianKL(mean, H, N_posterior_samples) -plot_reconstruction_2d(data, ground_truth, KL, signal, R, A) diff --git a/2_critical_filter.py b/2_critical_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..ba7e37bfc2dcb46b869fefc9fa4bb652070d9c7d --- /dev/null +++ b/2_critical_filter.py @@ -0,0 +1,24 @@ +import numpy as np + +import nifty5 as ift + +np.random.seed(42) + +position_space = ift.RGSpace(2*(256,)) +harmonic_space = position_space.get_default_codomain() +HT = ift.HarmonicTransformOperator(harmonic_space, target=position_space) +power_space = ift.PowerSpace(harmonic_space) + +A = ift.SLAmplitude( + **{ + 'target': power_space, + 'n_pix': 64, # 64 spectral bins + # Smoothness of spectrum + 'a': 10, # relatively high variance of spectral curvature + 'k0': .2, # quefrency mode below which cepstrum flattens + # Power-law part of spectrum + 'sm': -4, # preferred power-law slope + 'sv': .6, # low variance of power-law slope + 'im': -2, # y-intercept mean, in-/decrease for more/less contrast + 'iv': 2. # y-intercept variance + }) diff --git a/3_critical_filter_solution.py b/2_critical_filter_solution.py similarity index 100% rename from 3_critical_filter_solution.py rename to 2_critical_filter_solution.py diff --git a/2_gauss_lognormal.py b/2_gauss_lognormal.py deleted file mode 100644 index 4dff60ac6164741066d071b001add6d6b3b4a0cb..0000000000000000000000000000000000000000 --- a/2_gauss_lognormal.py +++ /dev/null @@ -1,89 +0,0 @@ -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see <http://www.gnu.org/licenses/>. -# -# Copyright(C) 2013-2019 Max-Planck-Society -# -# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. - -import numpy as np - -import nifty5 as ift -from helpers import (generate_gaussian_data, plot_prior_samples_2d, - plot_reconstruction_2d, radial_tomography_response) - -np.random.seed(42) - -position_space = ift.RGSpace([256, 256]) -harmonic_space = position_space.get_default_codomain() -power_space = ift.PowerSpace(harmonic_space) - -# Build model -HT = ift.HarmonicTransformOperator(harmonic_space, target=position_space) - -# Set up an amplitude operator for the field -# We want to set up a model for the amplitude spectrum with some magic numbers -dct = { - 'target': power_space, - 'n_pix': 64, # 64 spectral bins - # Spectral smoothness (affects Gaussian process part) - 'a': 10, # relatively high variance of spectral curvature - 'k0': .2, # quefrency mode below which cepstrum flattens - # Power-law part of spectrum: - 'sm': -4, # preferred power-law slope - 'sv': .6, # low variance of power-law slope - 'im': -3, # y-intercept mean, in-/decrease for more/less contrast - 'iv': 2. # y-intercept variance -} -A = ift.SLAmplitude(**dct) -correlated_field = ift.CorrelatedField(position_space, A) - -# Set up specific scenario -signal = correlated_field.exp() - -R = radial_tomography_response(position_space, lines_of_sight=256) -signal_response = R(signal) - -data_space = R.target -N = ift.ScalingOperator(5., data_space) -data, ground_truth = generate_gaussian_data(signal_response, N) -# Set up likelihood and information Hamiltonian -likelihood = ift.GaussianEnergy(data, N)(signal_response) - -# Plot prior samples -plot_prior_samples_2d(5, signal, R, correlated_field, A, 'gauss', N=N) - -# Solve problem -# Minimization parameters -ic_sampling = ift.GradientNormController(iteration_limit=100) -ic_newton = ift.GradInfNormController( - name='Newton', tol=1e-6, iteration_limit=30) -minimizer = ift.NewtonCG(ic_newton) - -H = ift.StandardHamiltonian(likelihood, ic_sampling) -initial_mean = ift.MultiField.full(H.domain, 0.) -mean = initial_mean - -# Number of samples used to estimate the KL -N_samples = 10 - -# Draw new samples to approximate the KL five times -for _ in range(5): - # Draw new samples and minimize KL - KL = ift.MetricGaussianKL(mean, H, N_samples) - KL, convergence = minimizer(KL) - mean = KL.position - -# Plot results -N_posterior_samples = 30 -KL = ift.MetricGaussianKL(mean, H, N_posterior_samples) -plot_reconstruction_2d(data, ground_truth, KL, signal, R, A) diff --git a/2_poisson_lognormal.py b/2_poisson_lognormal.py deleted file mode 100644 index 8a282614d8be684be4465421aef13f6989158885..0000000000000000000000000000000000000000 --- a/2_poisson_lognormal.py +++ /dev/null @@ -1,88 +0,0 @@ -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see <http://www.gnu.org/licenses/>. -# -# Copyright(C) 2013-2019 Max-Planck-Society -# -# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. - -import numpy as np - -import nifty5 as ift -from helpers import (exposure_response, generate_poisson_data, - plot_prior_samples_2d, plot_reconstruction_2d) - -np.random.seed(42) - -position_space = ift.RGSpace([256, 256]) -harmonic_space = position_space.get_default_codomain() -power_space = ift.PowerSpace(harmonic_space) - -# Build model -HT = ift.HarmonicTransformOperator(harmonic_space, target=position_space) - -# Set up an amplitude operator for the field -# We want to set up a model for the amplitude spectrum with some magic numbers -dct = { - 'target': power_space, - 'n_pix': 64, # 64 spectral bins - # Spectral smoothness (affects Gaussian process part) - 'a': 10, # relatively high variance of spectral curvature - 'k0': .3, # quefrency mode below which cepstrum flattens - # Power-law part of spectrum: - 'sm': -3, # preferred power-law slope - 'sv': .6, # low variance of power-law slope - 'im': -4, # y-intercept mean, in-/decrease for more/less contrast - 'iv': 2. # y-intercept variance -} -A = ift.SLAmplitude(**dct) -correlated_field = ift.CorrelatedField(position_space, A) - -# Set up specific scenario -signal = correlated_field.exp() - -R = exposure_response(position_space) -signal_response = R(signal) - -data_space = R.target -data, ground_truth = generate_poisson_data(signal_response) - -likelihood = ift.PoissonianEnergy(data)(signal_response) - -# Plot prior samples -plot_prior_samples_2d(5, signal, R, correlated_field, A, 'poisson') - -# Solve problem -# Minimization parameters -ic_sampling = ift.GradientNormController(iteration_limit=100) -ic_newton = ift.GradInfNormController( - name='Newton', tol=1e-6, iteration_limit=30) -minimizer = ift.NewtonCG(ic_newton) - -H = ift.StandardHamiltonian(likelihood, ic_sampling) -initial_mean = ift.MultiField.full(H.domain, 0.) -mean = initial_mean - -# Number of samples used to estimate the KL -N_samples = 5 - -# Draw new samples to approximate the KL five times -for _ in range(5): - # Draw new samples and minimize KL - KL = ift.MetricGaussianKL(mean, H, N_samples) - KL, convergence = minimizer(KL) - mean = KL.position - -# Plot results -N_posterior_samples = 30 -KL = ift.MetricGaussianKL(mean, H, N_posterior_samples) -plot_reconstruction_2d(data, ground_truth, KL, signal, R, A) diff --git a/3_more_examples.py b/3_more_examples.py new file mode 100644 index 0000000000000000000000000000000000000000..a843ed24914191f643cbeaed5074afccf707cb57 --- /dev/null +++ b/3_more_examples.py @@ -0,0 +1,91 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2019 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. + +import numpy as np + +import helpers as h +import nifty5 as ift + +seeds = [123, 42, 42] +name = ['bernoulli', 'gauss', 'poisson'] +for mode in [0, 1, 2]: + np.random.seed(seeds[mode]) + + position_space = ift.RGSpace([256, 256]) + harmonic_space = position_space.get_default_codomain() + power_space = ift.PowerSpace(harmonic_space) + HT = ift.HarmonicTransformOperator(harmonic_space, target=position_space) + + # Set up an amplitude operator for the field + dct = { + 'target': power_space, + 'n_pix': 64, # 64 spectral bins + # Spectral smoothness (affects Gaussian process part) + 'a': 10, # relatively high variance of spectral curvature + 'k0': .2, # quefrency mode below which cepstrum flattens + # Power-law part of spectrum: + 'sm': -4, # preferred power-law slope + 'sv': .6, # low variance of power-law slope + 'im': -3, # y-intercept mean, in-/decrease for more/less contrast + 'iv': 2. # y-intercept variance + } + A = ift.SLAmplitude(**dct) + correlated_field = ift.CorrelatedField(position_space, A) + + dct = {} + if mode == 0: + signal = correlated_field.sigmoid() + R = h.checkerboard_response(position_space) + elif mode == 1: + signal = correlated_field.exp() + R = h.radial_tomography_response(position_space, lines_of_sight=256) + N = ift.ScalingOperator(5., R.target) + dct['N'] = N + elif mode == 2: + signal = correlated_field.exp() + R = h.exposure_response(position_space) + h.plot_prior_samples_2d(5, signal, R, correlated_field, A, name[mode], + **dct) + signal_response = R @ signal + if mode == 0: + signal_response = signal_response.clip(1e-5, 1 - 1e-5) + data, ground_truth = h.generate_bernoulli_data(signal_response) + likelihood = ift.BernoulliEnergy(data) @ signal_response + elif mode == 1: + data, ground_truth = h.generate_gaussian_data(signal_response, N) + likelihood = ift.GaussianEnergy(data, N) @ signal_response + elif mode == 2: + data, ground_truth = h.generate_poisson_data(signal_response) + likelihood = ift.PoissonianEnergy(data) @ signal_response + + # Solve inference problem + ic_sampling = ift.GradientNormController(iteration_limit=100) + ic_newton = ift.GradInfNormController( + name='Newton', tol=1e-6, iteration_limit=50) + minimizer = ift.NewtonCG(ic_newton) + H = ift.StandardHamiltonian(likelihood, ic_sampling) + initial_mean = ift.MultiField.full(H.domain, 0.) + mean = initial_mean + N_samples = 5 if mode in [0, 2] else 10 + for _ in range(5): + # Draw new samples and minimize KL + KL = ift.MetricGaussianKL(mean, H, N_samples) + KL, convergence = minimizer(KL) + mean = KL.position + N_posterior_samples = 30 + KL = ift.MetricGaussianKL(mean, H, N_posterior_samples) + h.plot_reconstruction_2d(data, ground_truth, KL, signal, R, A, name[mode]) diff --git a/helpers/plot.py b/helpers/plot.py index 6c9d7829e786500837153e67618578781269df0b..adc3c5ae16bfeeee0895b276ecfbfe36516a5bc1 100644 --- a/helpers/plot.py +++ b/helpers/plot.py @@ -29,15 +29,15 @@ def plot_WF(name, mock, d, m=None, samples=None): dist = mock.domain[0].distances[0] npoints = mock.domain[0].shape[0] xcoord = np.arange(npoints, dtype=np.float64)*dist - plt.plot(xcoord, d.to_global_data(), 'kx', label="data") - plt.plot(xcoord, mock.to_global_data(), 'b-', label="ground truth") + plt.plot(xcoord, d.to_global_data(), 'kx', label='data') + plt.plot(xcoord, mock.to_global_data(), 'b-', label='ground truth') if m is not None: - plt.plot(xcoord, m.to_global_data(), 'k-', label="reconstruction") + plt.plot(xcoord, m.to_global_data(), 'k-', label='reconstruction') plt.title('reconstructed signal') plt.ylabel('value') plt.xlabel('position') if samples is not None: - std = 0. + std = 0 for s in samples: std = std + (s - m)**2 std = std/len(samples) @@ -49,8 +49,8 @@ def plot_WF(name, mock, d, m=None, samples=None): md - std, md + std, alpha=0.3, - color="k", - label=r"standard deviation") + color='k', + label='standard deviation') plt.legend() x1, x2, y1, y2 = plt.axis() ymin = np.min(d.to_global_data()) - 0.1 @@ -58,30 +58,30 @@ def plot_WF(name, mock, d, m=None, samples=None): xmin = np.min(xcoord) xmax = np.max(xcoord) plt.axis((xmin, xmax, ymin, ymax)) - plt.savefig(name + ".png", dpi=300) + plt.savefig('{}.png'.format(name), dpi=300) plt.close('all') def power_plot(name, s, m, samples=None): plt.figure(figsize=(15, 8)) ks = s.domain[0].k_lengths - plt.xscale("log") - plt.yscale("log") - plt.plot(ks, s.to_global_data(), 'b-', label="ground truth") - plt.plot(ks, m.to_global_data(), 'k-', label="reconstruction") + plt.xscale('log') + plt.yscale('log') + plt.plot(ks, s.to_global_data(), 'b-', label='ground truth') + plt.plot(ks, m.to_global_data(), 'k-', label='reconstruction') plt.title('reconstructed power-spectrum') plt.ylabel('power') plt.xlabel('harmonic mode') if samples is not None: for i in range(len(samples)): if i == 0: - lgd = "samples" + lgd = 'samples' else: lgd = None plt.plot( ks, samples[i].to_global_data(), 'k-', alpha=0.3, label=lgd) plt.legend() - plt.savefig(name + ".png", dpi=300) + plt.savefig('{}.png'.format(name), dpi=300) plt.close('all') @@ -92,18 +92,13 @@ def plot_prior_samples_2d(n_samps, A, likelihood, N=None): - fig, ax = plt.subplots( - nrows=n_samps, ncols=5, figsize=( - 2*5, - 2*n_samps, - )) + fig, ax = plt.subplots(nrows=n_samps, ncols=5, figsize=(2*5, 2*n_samps)) for s in range(n_samps): sample = ift.from_random('normal', signal.domain) cf = correlated_field(sample) signal_response = R(signal) sg = signal(sample) sr = R.adjoint(R(signal(sample))) - pow = A.force(sample) if likelihood == 'gauss': data = signal_response(sample) + N.draw_sample() elif likelihood == 'poisson': @@ -118,7 +113,8 @@ def plot_prior_samples_2d(n_samps, raise ValueError('likelihood type not implemented') data = R.adjoint(data + 0.) - ax[s, 0].plot(pow.domain[0].k_lengths, pow.to_global_data()) + As = A.force(sample) + ax[s, 0].plot(As.domain[0].k_lengths, As.to_global_data()) ax[s, 0].set_yscale('log') ax[s, 0].set_xscale('log') ax[s, 0].get_xaxis().set_visible(False) @@ -148,28 +144,21 @@ def plot_prior_samples_2d(n_samps, ax[0, 4].set_title('synthetic data') ax[n_samps - 1, 0].get_xaxis().set_visible(True) - plt.tight_layout() - plt.savefig('prior_samples_' + likelihood + '.png') + plt.savefig('prior_samples_{}.png'.format(likelihood)) plt.close('all') -def plot_reconstruction_2d(data, ground_truth, KL, signal, R, A): +def plot_reconstruction_2d(data, ground_truth, KL, signal, R, A, name): sc = ift.StatCalculator() - sky_samples = [] - amp_samples = [] + sky_samples, amp_samples = [], [] for sample in KL.samples: tmp = signal(sample + KL.position) - pow = A.force(sample) sc.add(tmp) - sky_samples += [tmp] - amp_samples += [pow] - - fig, ax = plt.subplots( - nrows=2, ncols=3, figsize=( - 4*3, - 4*2, - )) + sky_samples.append(tmp) + amp_samples.append(A.force(sample)) + + fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(4*3, 4*2)) im = list() im.append(ax[0, 0].imshow( signal(ground_truth).to_global_data(), aspect='auto')) @@ -193,7 +182,7 @@ def plot_reconstruction_2d(data, ground_truth, KL, signal, R, A): ax[1, 2].plot( s.domain[0].k_lengths, s.to_global_data(), color='lightgrey') - amp_mean = sum(amp_samples)/(len(amp_samples)) + amp_mean = sum(amp_samples)/len(amp_samples) ax[1, 2].plot( amp_mean.domain[0].k_lengths, amp_mean.to_global_data(), @@ -209,19 +198,16 @@ def plot_reconstruction_2d(data, ground_truth, KL, signal, R, A): ax[1, 2].set_xscale('log') ax[1, 2].set_title('power-spectra') - c = 0 - for i, j in product(range(2), range(3)): - if not (i == 1 and j == 2): + for c, i, j in enumerate(product(range(2), range(3))): + if i != 1 or j != 2: ax[i, j].get_xaxis().set_visible(False) ax[i, j].get_yaxis().set_visible(False) divider = make_axes_locatable(ax[i, j]) - ax_cb = divider.new_horizontal(size="5%", pad=0.05) + ax_cb = divider.new_horizontal(size='5%', pad=0.05) fig1 = ax[i, j].get_figure() fig1.add_axes(ax_cb) ax[i, j].figure.colorbar( im[c], ax=ax[i, j], cax=ax_cb, orientation='vertical') - c += 1 - plt.tight_layout() - plt.savefig('reconstruction.png') + plt.savefig('reconstruction{}.png'.format(name)) plt.close('all')