Commit 9a93746b authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'try_different_sampling' into 'NIFTy_4'

Try different sampling

See merge request ift/NIFTy!245
parents 5884a0eb 6a6efcd9
Pipeline #27816 passed with stages
in 20 minutes and 52 seconds
import nifty4 as ift
import numpy as np
import matplotlib.pyplot as plt
from nifty4.sugar import create_power_operator
np.random.seed(42)
x_space = ift.RGSpace(1024)
h_space = x_space.get_default_codomain()
d_space = x_space
N_hat = ift.Field(d_space, 10.)
N_hat.val[400:450] = 0.0001
N = ift.DiagonalOperator(N_hat, d_space)
FFT = ift.HarmonicTransformOperator(h_space, target=x_space)
R = ift.ScalingOperator(1., x_space)
def ampspec(k): return 1. / (1. + k**2.)
S = ift.ScalingOperator(1., h_space)
A = create_power_operator(h_space, ampspec)
s_h = S.draw_sample()
sky = FFT * A
s_x = sky(s_h)
n = N.draw_sample()
d = R(s_x) + n
R_p = R * FFT * A
j = R_p.adjoint(N.inverse(d))
D_inv = R_p.adjoint * N.inverse * R_p + S.inverse
history = []
def sample(D_inv, S, j, N_samps, N_iter):
global history
space = D_inv.domain
x = ift.Field.zeros(space)
r = j.copy()
p = r.copy()
d = p.vdot(D_inv(p))
y = []
for i in range(N_samps):
y += [S.draw_sample()]
for k in range(1, 1 + N_iter):
history += [y[0].copy()]
gamma = r.vdot(r) / d
if gamma == 0.:
break
x += gamma * p
#print(p.vdot(D_inv(j)))
for i in range(N_samps):
y[i] -= p.vdot(D_inv(y[i])) * p / d
y[i] += np.random.randn() / np.sqrt(d) * p
print("variance iteration "+str(k)+":", np.sqrt(p.vdot(p)/d))
#r_new = j - D_inv(x)
r_new = r - gamma * D_inv(p)
beta = r_new.vdot(r_new) / (r.vdot(r))
r = r_new
p = r + beta * p
d = p.vdot(D_inv(p))
if d == 0.:
break
return x, y
N_samps = 200
N_iter = 10
m, samps = sample(D_inv, S, j, N_samps, N_iter)
m_x = sky(m)
IC = ift.GradientNormController(iteration_limit=N_iter)
inverter = ift.ConjugateGradient(IC)
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter)
samps_old = []
for i in range(N_samps):
samps_old += [curv.draw_sample(from_inverse=True)]
plt.plot(d.val, '+', label="data", alpha=.5)
plt.plot(s_x.val, label="original")
plt.plot(m_x.val, label="reconstruction")
plt.legend()
plt.savefig('Krylov_reconstruction.png')
plt.close()
pltdict = {'alpha': .3, 'linewidth': .2}
for i in range(N_samps):
if i == 0:
plt.plot(sky(samps_old[i]).val, color='b',
label='Traditional samples (residuals)',
**pltdict)
plt.plot(sky(samps[i]).val, color='r',
label='Krylov samples (residuals)',
**pltdict)
else:
plt.plot(sky(samps_old[i]).val, color='b', **pltdict)
plt.plot(sky(samps[i]).val, color='r', **pltdict)
plt.plot((s_x - m_x).val, color='k', label='signal - mean')
plt.legend()
plt.savefig('Krylov_samples_residuals.png')
plt.close()
D_hat_old = ift.Field.zeros(x_space).val
D_hat_new = ift.Field.zeros(x_space).val
for i in range(N_samps):
D_hat_old += sky(samps_old[i]).val**2
D_hat_new += sky(samps[i]).val**2
plt.plot(np.sqrt(D_hat_old / N_samps), 'r--', label='Traditional uncertainty')
plt.plot(-np.sqrt(D_hat_old / N_samps), 'r--')
plt.fill_between(range(len(D_hat_new)), -np.sqrt(D_hat_new / N_samps), np.sqrt(
D_hat_new / N_samps), facecolor='0.5', alpha=0.5,
label='Krylov uncertainty')
plt.plot((s_x - m_x).val, color='k', label='signal - mean')
plt.legend()
plt.savefig('Krylov_uncertainty.png')
plt.close()
for i in range(min(6, len(history))):
plt.plot(sky(history[i]).val, label="step " + str(i+1))
plt.plot(s_x.val-m_x.val, 'k-', label="residual")
plt.legend()
plt.savefig('iterations.png')
plt.close()
......@@ -5,3 +5,4 @@ from .nonlinear_power_energy import NonlinearPowerEnergy
from .nonlinear_wiener_filter_energy import NonlinearWienerFilterEnergy
from .poisson_energy import PoissonEnergy
from .nonlinearities import Exponential, Linear, Tanh, PositiveTanh
from .krylov_sampling import generate_krylov_samples
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from numpy import sqrt
from numpy.random import randn
def generate_krylov_samples(D_inv, S, j=None, N_samps=1, N_iter=10,
name=None):
"""
Generates inverse samples from a curvature D.
This algorithm iteratively generates samples from
a curvature D by applying conjugate gradient steps
and resampling the curvature in search direction.
Parameters
----------
D_inv : EndomorphicOperator
The curvature which will be the inverse of the covarianc
of the generated samples
S : EndomorphicOperator (from which one can sample)
A prior covariance operator which is used to generate prior
samples that are then iteratively updated
j : Field, optional
A Field to which the inverse of D_inv is applied. The solution
of this matrix inversion problem is a side product of generating
the samples.
If not supplied, it is sampled from the inverse prior.
N_samps : Int, optional
How many samples to generate. Default: 1
N_iter : Int, optional
How many iterations of the conjugate gradient to run. Default: 10
Returns
-------
(solution, samples) : A tuple of a field 'solution' and a list of fields
'samples'. The first entry of the tuple is the solution x to
D_inv(x) = j
and the second entry are a list of samples from D_inv.inverse
"""
j = S.draw_sample(from_inverse=True) if j is None else j
x = S.draw_sample()
r = j.copy()
p = r.copy()
d = p.vdot(D_inv(p))
y = [S.draw_sample() for _ in range(N_samps)]
for k in range(1, 1+N_iter):
gamma = r.vdot(r)/d
if gamma == 0.:
break
x += gamma*p
for i in range(N_samps):
y[i] -= p.vdot(D_inv(y[i])) * p / d
y[i] += randn() / sqrt(d) * p
r_new = r - gamma * D_inv(p)
beta = r_new.vdot(r_new) / r.vdot(r)
r = r_new
p = r + beta * p
d = p.vdot(D_inv(p))
if d == 0.:
break
if name is not None:
print('{}: Iteration #{}'.format(name, k))
return x, y
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment