Skip to content
Snippets Groups Projects
Commit a075bd9a authored by Philipp Arras's avatar Philipp Arras
Browse files

Cosmetics and actually plot power spectrum not amplitude

parent 153dfeb8
Branches
No related tags found
No related merge requests found
......@@ -92,13 +92,21 @@ def plot_prior_samples_2d(n_samps,
A,
likelihood,
N=None):
samples, pspecmin, pspecmax = [], np.inf, 0
pspec = A*A
for _ in range(n_samps):
ss = ift.from_random('normal', signal.domain)
samples.append(ss)
foo = pspec.force(ss).to_global_data()
pspecmin = min([min(foo), pspecmin])
pspecmax = max([max(foo), pspecmin])
fig, ax = plt.subplots(nrows=n_samps, ncols=5, figsize=(2*5, 2*n_samps))
for s in range(n_samps):
sample = ift.from_random('normal', signal.domain)
for ii, sample in enumerate(samples):
cf = correlated_field(sample)
signal_response = R(signal)
signal_response = R @ signal
sg = signal(sample)
sr = R.adjoint(R(signal(sample)))
sr = (R.adjoint @ R @ signal)(sample)
if likelihood == 'gauss':
data = signal_response(sample) + N.draw_sample()
elif likelihood == 'poisson':
......@@ -113,31 +121,32 @@ def plot_prior_samples_2d(n_samps,
raise ValueError('likelihood type not implemented')
data = R.adjoint(data + 0.)
As = A.force(sample)
ax[s, 0].plot(As.domain[0].k_lengths, As.to_global_data())
ax[s, 0].set_yscale('log')
ax[s, 0].set_xscale('log')
ax[s, 0].get_xaxis().set_visible(False)
As = pspec.force(sample)
ax[ii, 0].plot(As.domain[0].k_lengths, As.to_global_data())
ax[ii, 0].set_ylim(pspecmin, pspecmax)
ax[ii, 0].set_yscale('log')
ax[ii, 0].set_xscale('log')
ax[ii, 0].get_xaxis().set_visible(False)
ax[s, 1].imshow(cf.to_global_data(), aspect='auto')
ax[s, 1].get_xaxis().set_visible(False)
ax[s, 1].get_yaxis().set_visible(False)
ax[ii, 1].imshow(cf.to_global_data(), aspect='auto')
ax[ii, 1].get_xaxis().set_visible(False)
ax[ii, 1].get_yaxis().set_visible(False)
ax[s, 2].imshow(sg.to_global_data(), aspect='auto')
ax[s, 2].get_xaxis().set_visible(False)
ax[s, 2].get_yaxis().set_visible(False)
ax[ii, 2].imshow(sg.to_global_data(), aspect='auto')
ax[ii, 2].get_xaxis().set_visible(False)
ax[ii, 2].get_yaxis().set_visible(False)
ax[s, 3].imshow(sr.to_global_data(), aspect='auto')
ax[s, 3].get_xaxis().set_visible(False)
ax[s, 3].get_yaxis().set_visible(False)
ax[ii, 3].imshow(sr.to_global_data(), aspect='auto')
ax[ii, 3].get_xaxis().set_visible(False)
ax[ii, 3].get_yaxis().set_visible(False)
ax[s, 4].imshow(data.to_global_data(), cmap='viridis', aspect='auto')
ax[s, 4].get_xaxis().set_visible(False)
ax[s, 4].yaxis.tick_right()
ax[s, 4].get_yaxis().set_visible(False)
ax[ii, 4].imshow(data.to_global_data(), cmap='viridis', aspect='auto')
ax[ii, 4].get_xaxis().set_visible(False)
ax[ii, 4].yaxis.tick_right()
ax[ii, 4].get_yaxis().set_visible(False)
if s == 0:
ax[0, 0].set_title('power-spectrum')
if ii == 0:
ax[0, 0].set_title('power spectrum')
ax[0, 1].set_title('correlated field')
ax[0, 2].set_title('signal')
ax[0, 3].set_title('signal Response')
......@@ -151,15 +160,15 @@ def plot_prior_samples_2d(n_samps,
def plot_reconstruction_2d(data, ground_truth, KL, signal, R, A, name):
sc = ift.StatCalculator()
sky_samples, amp_samples = [], []
sky_samples, pspec_samples = [], []
for sample in KL.samples:
tmp = signal(sample + KL.position)
sc.add(tmp)
sky_samples.append(tmp)
amp_samples.append(A.force(sample))
pspec_samples.append(A.force(sample)**2)
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(4*3, 4*2))
im = list()
im = []
im.append(ax[0, 0].imshow(
signal(ground_truth).to_global_data(), aspect='auto'))
ax[0, 0].set_title('true signal')
......@@ -178,11 +187,11 @@ def plot_reconstruction_2d(data, ground_truth, KL, signal, R, A, name):
ift.sqrt(sc.var).to_global_data(), aspect='auto'))
ax[1, 1].set_title('standard deviation')
for s in amp_samples:
for ss in pspec_samples:
ax[1, 2].plot(
s.domain[0].k_lengths, s.to_global_data(), color='lightgrey')
ss.domain[0].k_lengths, ss.to_global_data(), color='lightgrey')
amp_mean = sum(amp_samples)/len(amp_samples)
amp_mean = sum(pspec_samples)/len(pspec_samples)
ax[1, 2].plot(
amp_mean.domain[0].k_lengths,
amp_mean.to_global_data(),
......@@ -196,7 +205,7 @@ def plot_reconstruction_2d(data, ground_truth, KL, signal, R, A, name):
ax[1, 2].legend()
ax[1, 2].set_yscale('log')
ax[1, 2].set_xscale('log')
ax[1, 2].set_title('power-spectra')
ax[1, 2].set_title('power spectra')
for c, i, j in enumerate(product(range(2), range(3))):
if i != 1 or j != 2:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment