Commit 208763a9 authored by Philipp Arras's avatar Philipp Arras
Browse files

Remove Krylov sampling

parent 03823a8d
import nifty4 as ift
import numpy as np
import matplotlib.pyplot as plt
from nifty4.sugar import create_power_operator
np.random.seed(42)
x_space = ift.RGSpace(1024)
h_space = x_space.get_default_codomain()
d_space = x_space
N_hat = np.full(d_space.shape, 10.)
N_hat[400:450] = 0.0001
N_hat = ift.Field.from_global_data(d_space, N_hat)
N = ift.DiagonalOperator(N_hat)
FFT = ift.HarmonicTransformOperator(h_space, x_space)
R = ift.ScalingOperator(1., x_space)
def ampspec(k): return 1. / (1. + k**2.)
S = ift.ScalingOperator(1., h_space)
A = create_power_operator(h_space, ampspec)
s_h = S.draw_sample()
sky = FFT * A
s_x = sky(s_h)
n = N.draw_sample()
d = R(s_x) + n
R_p = R * FFT * A
j = R_p.adjoint(N.inverse(d))
D_inv = ift.SandwichOperator.make(R_p, N.inverse) + S.inverse
N_samps = 200
N_iter = 100
IC = ift.GradientNormController(tol_abs_gradnorm=1e-3, iteration_limit=N_iter)
m, samps = ift.library.generate_krylov_samples(D_inv, S, j, N_samps, IC)
m_x = sky(m)
inverter = ift.ConjugateGradient(IC)
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter,
sampling_inverter=inverter)
samps_old = [curv.draw_sample(from_inverse=True) for i in range(N_samps)]
plt.plot(d.to_global_data(), '+', label="data", alpha=.5)
plt.plot(s_x.to_global_data(), label="original")
plt.plot(m_x.to_global_data(), label="reconstruction")
plt.legend()
plt.savefig('Krylov_reconstruction.png')
plt.close()
pltdict = {'alpha': .3, 'linewidth': .2}
for i in range(N_samps):
if i == 0:
plt.plot(sky(samps_old[i]).to_global_data(), color='b',
label='Traditional samples (residuals)',
**pltdict)
plt.plot(sky(samps[i]).to_global_data(), color='r',
label='Krylov samples (residuals)',
**pltdict)
else:
plt.plot(sky(samps_old[i]).to_global_data(), color='b', **pltdict)
plt.plot(sky(samps[i]).to_global_data(), color='r', **pltdict)
plt.plot((s_x - m_x).to_global_data(), color='k', label='signal - mean')
plt.legend()
plt.savefig('Krylov_samples_residuals.png')
plt.close()
D_hat_old = ift.full(x_space, 0.).to_global_data()
D_hat_new = ift.full(x_space, 0.).to_global_data()
for i in range(N_samps):
D_hat_old += sky(samps_old[i]).to_global_data()**2
D_hat_new += sky(samps[i]).to_global_data()**2
plt.plot(np.sqrt(D_hat_old / N_samps), 'r--', label='Traditional uncertainty')
plt.plot(-np.sqrt(D_hat_old / N_samps), 'r--')
plt.fill_between(range(len(D_hat_new)), -np.sqrt(D_hat_new / N_samps), np.sqrt(
D_hat_new / N_samps), facecolor='0.5', alpha=0.5,
label='Krylov uncertainty')
plt.plot((s_x - m_x).to_global_data(), color='k', label='signal - mean')
plt.legend()
plt.savefig('Krylov_uncertainty.png')
plt.close()
......@@ -5,5 +5,4 @@ from .nonlinear_power_energy import NonlinearPowerEnergy
from .nonlinear_wiener_filter_energy import NonlinearWienerFilterEnergy
from .poisson_energy import PoissonEnergy
from .nonlinearities import Exponential, Linear, Tanh, PositiveTanh
from .krylov_sampling import generate_krylov_samples
from .los_response import LOSResponse
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from ..minimization.quadratic_energy import QuadraticEnergy
def generate_krylov_samples(D_inv, S, j, N_samps, controller):
"""
Generates inverse samples from a curvature D.
This algorithm iteratively generates samples from
a curvature D by applying conjugate gradient steps
and resampling the curvature in search direction.
Parameters
----------
D_inv : EndomorphicOperator
The curvature which will be the inverse of the covarianc
of the generated samples
S : EndomorphicOperator (from which one can sample)
A prior covariance operator which is used to generate prior
samples that are then iteratively updated
j : Field, optional
A Field to which the inverse of D_inv is applied. The solution
of this matrix inversion problem is a side product of generating
the samples.
If not supplied, it is sampled from the inverse prior.
N_samps : Int
How many samples to generate.
controller : IterationController
convergence controller for the conjugate gradient iteration
Returns
-------
(solution, samples) : A tuple of a field 'solution' and a list of fields
'samples'. The first entry of the tuple is the solution x to
D_inv(x) = j
and the second entry are a list of samples from D_inv.inverse
"""
# RL FIXME: make consistent with complex numbers
j = S.draw_sample(from_inverse=True) if j is None else j
energy = QuadraticEnergy(j.empty_copy().fill(0.), D_inv, j)
y = [S.draw_sample() for _ in range(N_samps)]
status = controller.start(energy)
if status != controller.CONTINUE:
return energy.position, y
r = energy.gradient
d = r.copy()
previous_gamma = r.vdot(r).real
if previous_gamma == 0:
return energy.position, y
while True:
q = energy.curvature(d)
ddotq = d.vdot(q).real
if ddotq == 0.:
logger.error("Error: ConjugateGradient: ddotq==0.")
return energy.position, y
alpha = previous_gamma/ddotq
if alpha < 0:
logger.error("Error: ConjugateGradient: alpha<0.")
return energy.position, y
for i in range(len(y)):
y[i] += (np.random.randn()*np.sqrt(ddotq) - y[i].vdot(q))/ddotq * d
q *= -alpha
r = r + q
energy = energy.at_with_grad(energy.position - alpha*d, r)
gamma = r.vdot(r).real
if gamma == 0:
return energy.position, y
status = controller.check(energy)
if status != controller.CONTINUE:
return energy.position, y
d *= max(0, gamma/previous_gamma)
d += r
previous_gamma = gamma
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment