# plot.py
# Authored by Philipp Arras
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from itertools import product
import numpy as np
import pylab as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import nifty7 as ift
def plot_WF(name, mock, d, m=None, samples=None):
    """Plot data, ground truth and optionally a reconstruction in 1D.

    The figure is saved as ``'<name>.png'`` (300 dpi) and all open
    matplotlib figures are closed afterwards.

    Parameters
    ----------
    name : str
        File name stem; ``'.png'`` is appended.
    mock : Field
        Ground truth, drawn as a blue line.  The pixel distance and the
        number of pixels of its first domain entry define the x-axis.
    d : Field
        Data, drawn as black crosses.  Its value range, padded by 0.1 on
        both sides, fixes the y-limits of the plot.
    m : Field, optional
        Reconstruction, drawn as a black line.
    samples : sequence of Field, optional
        Samples used to draw an uncertainty band around `m`.  The band
        half-width is the root of the mean squared deviation of the
        samples from `m`.  An empty sequence draws no band.

    Raises
    ------
    ValueError
        If a non-empty `samples` is passed without `m`; the band is
        centred on `m` and cannot be computed without it.
    """
    # Validate before touching matplotlib so a bad call leaves no
    # half-drawn figure behind.
    have_samples = samples is not None and len(samples) > 0
    if have_samples and m is None:
        raise ValueError('`samples` requires the reconstruction `m`')

    plt.figure(figsize=(15, 8))
    space = mock.domain[0]
    xcoord = np.arange(space.shape[0], dtype=np.float64)*space.distances[0]

    plt.plot(xcoord, d.val, 'kx', label='data')
    plt.plot(xcoord, mock.val, 'b-', label='ground truth')
    if m is not None:
        plt.plot(xcoord, m.val, 'k-', label='reconstruction')
    plt.title('reconstructed signal')
    plt.ylabel('value')
    plt.xlabel('position')

    if have_samples:
        mean_sq_dev = sum((s - m)**2 for s in samples)/len(samples)
        std = np.sqrt(mean_sq_dev.val)
        md = m.val
        plt.fill_between(
            xcoord,
            md - std,
            md + std,
            alpha=0.3,
            color='k',
            label='standard deviation')
    plt.legend()

    # The y-range follows the data only, not the reconstruction or band.
    ymin = np.min(d.val) - 0.1
    ymax = np.max(d.val) + 0.1
    plt.axis((np.min(xcoord), np.max(xcoord), ymin, ymax))
    plt.savefig('{}.png'.format(name), dpi=300)
    plt.close('all')
def power_plot(name, s, m, samples=None):
    """Plot a reconstructed power spectrum next to the ground truth.

    Both axes are logarithmic.  The figure is saved as ``'<name>.png'``
    (300 dpi) and all open matplotlib figures are closed afterwards.

    Parameters
    ----------
    name : str
        File name stem; ``'.png'`` is appended.
    s : Field
        Ground-truth spectrum (blue line).  The ``k_lengths`` of its
        first domain entry are used as x-values for every curve.
    m : Field
        Reconstructed spectrum (black line).
    samples : sequence of Field, optional
        Additional spectra drawn as translucent black lines.  Only the
        first one gets a legend entry.
    """
    plt.figure(figsize=(15, 8))
    k_lengths = s.domain[0].k_lengths
    plt.xscale('log')
    plt.yscale('log')

    plt.plot(k_lengths, s.val, 'b-', label='ground truth')
    plt.plot(k_lengths, m.val, 'k-', label='reconstruction')
    plt.title('reconstructed power-spectrum')
    plt.ylabel('power')
    plt.xlabel('harmonic mode')

    if samples is not None:
        for index, spectrum in enumerate(samples):
            # Label only the first curve so the legend has one entry.
            legend_entry = 'samples' if index == 0 else None
            plt.plot(k_lengths, spectrum.val, 'k-', alpha=0.3,
                     label=legend_entry)

    plt.legend()
    plt.savefig('{}.png'.format(name), dpi=300)
    plt.close('all')
def plot_prior_samples_2d(n_samps, signal, R, correlated_field, pspec,
                          likelihood, N=None):
    """Draw prior samples and plot them together with synthetic data.

    One row is drawn per sample with five panels: power spectrum,
    correlated field, signal, back-projected signal response
    (``R.adjoint @ R @ signal``) and back-projected synthetic data.  The
    figure is saved as ``'prior_samples_<likelihood>.png'`` and all open
    matplotlib figures are closed afterwards.

    Parameters
    ----------
    n_samps : int
        Number of prior samples (rows); must be at least 1.
    signal : Operator
        Signal model; samples are drawn on ``signal.domain``.
    R : Operator
        Response.  ``R @ signal`` produces the expected data and
        ``R.adjoint`` is used to map data back for plotting.
    correlated_field : Operator
        Applied to each sample for the second panel.
    pspec : Operator
        Power spectrum model, evaluated via ``pspec.force``; the
        ``k_lengths`` of its first target entry form the x-axis.
    likelihood : {'gauss', 'poisson', 'bernoulli'}
        Selects how synthetic data are drawn from the signal response.
    N : Operator, optional
        Noise operator providing ``draw_sample_with_dtype``; required
        when ``likelihood == 'gauss'``.

    Raises
    ------
    ValueError
        If `likelihood` is not one of the supported names, if `N` is
        missing for a Gaussian likelihood, or if `n_samps` is smaller
        than 1.
    """
    # Check the arguments first so that no samples are drawn and no
    # figure is opened for a call that cannot succeed.
    if likelihood not in ('gauss', 'poisson', 'bernoulli'):
        raise ValueError('likelihood type not implemented')
    if likelihood == 'gauss' and N is None:
        raise ValueError("likelihood 'gauss' requires the noise operator N")
    if n_samps < 1:
        raise ValueError('n_samps must be at least 1')

    samples = [ift.from_random(signal.domain) for _ in range(n_samps)]
    spectra = [pspec.force(sample).val for sample in samples]
    # Common y-range over ALL samples, padded by one decade on each side.
    pspecmin = min(np.min(spec) for spec in spectra)/10
    pspecmax = max(np.max(spec) for spec in spectra)*10

    # These compositions do not depend on the sample; build them once.
    k_lengths = pspec.target[0].k_lengths
    signal_response = R @ signal
    back_projected_response = R.adjoint @ R @ signal

    # squeeze=False keeps `ax` two-dimensional even for a single row.
    fig, ax = plt.subplots(nrows=n_samps, ncols=5,
                           figsize=(2*5, 2*n_samps), squeeze=False)
    for ii, sample in enumerate(samples):
        cf = correlated_field(sample)
        sg = signal(sample)
        sr = back_projected_response(sample)

        if likelihood == 'gauss':
            data = (signal_response(sample)
                    + N.draw_sample_with_dtype(np.float64))
        elif likelihood == 'poisson':
            rate = signal_response(sample).val
            data = ift.makeField(signal_response.target,
                                 np.random.poisson(rate))
        else:  # 'bernoulli'
            rate = signal_response(sample).val
            data = ift.makeField(signal_response.target,
                                 np.random.binomial(1, rate))
        # `+ 0.` turns integer draws into floats before back-projection.
        data = R.adjoint(data + 0.)

        ax[ii, 0].plot(k_lengths, spectra[ii])
        ax[ii, 0].set_yscale('log')
        ax[ii, 0].set_xscale('log')
        ax[ii, 0].set_ylim(pspecmin, pspecmax)
        ax[ii, 0].get_xaxis().set_visible(False)

        for col, field in enumerate((cf, sg, sr), start=1):
            ax[ii, col].imshow(field.val, aspect='auto')
            ax[ii, col].get_xaxis().set_visible(False)
            ax[ii, col].get_yaxis().set_visible(False)

        ax[ii, 4].imshow(data.val, cmap='viridis', aspect='auto')
        ax[ii, 4].get_xaxis().set_visible(False)
        ax[ii, 4].yaxis.tick_right()
        ax[ii, 4].get_yaxis().set_visible(False)

    ax[0, 0].set_title('power spectrum')
    ax[0, 1].set_title('correlated field')
    ax[0, 2].set_title('signal')
    ax[0, 3].set_title('signal Response')
    ax[0, 4].set_title('synthetic data')
    # Only the bottom power-spectrum panel shows its x-axis.
    ax[n_samps - 1, 0].get_xaxis().set_visible(True)
    plt.tight_layout()
    plt.savefig('prior_samples_{}.png'.format(likelihood))
    plt.close('all')
def plot_reconstruction_2d(data, ground_truth, KL, signal, R, pspec, name):
    """Plot a 2x3 summary of a 2D reconstruction.

    Panels: true signal, back-projected response of the posterior mean,
    back-projected data, posterior mean, posterior standard deviation
    and power spectra (samples, their mean, ground truth).  Every image
    panel gets its own colorbar.  The figure is saved as
    ``'reconstruction<name>.png'`` and all open matplotlib figures are
    closed afterwards.

    Parameters
    ----------
    data : Field
        Data; shown after applying ``R.adjoint``.
    ground_truth : Field
        Input for which ``signal`` and ``pspec`` give the true signal
        and true power spectrum.
    KL : object
        Provides ``samples`` (iterable) and ``position``.  Each sample
        is added to ``KL.position`` before being evaluated.
    signal : Operator
        Signal model.
    R : Operator
        Response; ``R`` and ``R.adjoint`` are applied to fields.
    pspec : Operator
        Power spectrum model, evaluated via ``pspec.force``.
    name : str
        Suffix inserted into the output file name.

    Raises
    ------
    ValueError
        If ``KL.samples`` yields no samples.
    """
    sc = ift.StatCalculator()
    pspec_samples = []
    for sample in KL.samples:
        # The samples are offsets around KL.position, so both the signal
        # and the power spectrum are evaluated at position + sample.
        latent = sample + KL.position
        sc.add(signal(latent))
        pspec_samples.append(pspec.force(latent))
    if not pspec_samples:
        raise ValueError('KL does not provide any samples')

    fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(4*3, 4*2))
    im = []

    truth = signal(ground_truth).val
    # True signal and posterior mean share one colour scale.
    vmin, vmax = np.min(truth), np.max(truth)
    im.append(ax[0, 0].imshow(truth, aspect='auto', vmin=vmin, vmax=vmax))
    ax[0, 0].set_title('true signal')

    im.append(ax[0, 1].imshow(R.adjoint(R(sc.mean)).val, aspect='auto'))
    ax[0, 1].set_title('signal response')

    im.append(ax[0, 2].imshow(R.adjoint(data).val, aspect='auto'))
    ax[0, 2].set_title('data')

    im.append(ax[1, 0].imshow(
        sc.mean.val, aspect='auto', vmin=vmin, vmax=vmax))
    ax[1, 0].set_title('posterior mean')

    im.append(ax[1, 1].imshow(ift.sqrt(sc.var).val, aspect='auto'))
    ax[1, 1].set_title('standard deviation')

    spectrum_ax = ax[1, 2]
    for spectrum in pspec_samples:
        spectrum_ax.plot(
            spectrum.domain[0].k_lengths, spectrum.val, color='lightgrey')
    amp_mean = sum(pspec_samples)/len(pspec_samples)
    k_lengths = amp_mean.domain[0].k_lengths
    spectrum_ax.plot(
        k_lengths, amp_mean.val, color='black', label='reconstruction')
    spectrum_ax.plot(
        k_lengths,
        pspec.force(ground_truth).val,
        color='b',
        label='ground truth')
    spectrum_ax.legend()
    spectrum_ax.set_yscale('log')
    spectrum_ax.set_xscale('log')
    spectrum_ax.set_title('power spectra')

    # `ax.flat` is row-major and `im` holds the five image panels in the
    # same order, so zip() pairs them and leaves out the spectrum panel.
    for axis, image in zip(ax.flat, im):
        axis.get_xaxis().set_visible(False)
        axis.get_yaxis().set_visible(False)
        divider = make_axes_locatable(axis)
        ax_cb = divider.new_horizontal(size='5%', pad=0.05)
        fig.add_axes(ax_cb)
        fig.colorbar(image, ax=axis, cax=ax_cb, orientation='vertical')

    plt.tight_layout()
    plt.savefig('reconstruction{}.png'.format(name))
    plt.close('all')