Commit 329b16b8 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

fix Krylov code

parent 547df672
Pipeline #27866 passed with stages
in 13 minutes and 38 seconds
......@@ -9,11 +9,12 @@ x_space = ift.RGSpace(1024)
h_space = x_space.get_default_codomain()
d_space = x_space
N_hat = ift.Field(d_space, 10.)
N_hat.val[400:450] = 0.0001
N = ift.DiagonalOperator(N_hat, d_space)
N_hat = np.full(d_space.shape, 10.)
N_hat[400:450] = 0.0001
N_hat = ift.Field.from_global_data(d_space, N_hat)
N = ift.DiagonalOperator(N_hat)
FFT = ift.HarmonicTransformOperator(h_space, target=x_space)
FFT = ift.HarmonicTransformOperator(h_space, x_space)
R = ift.ScalingOperator(1., x_space)
......@@ -34,19 +35,17 @@ D_inv = ift.SandwichOperator(R_p, N.inverse) + S.inverse
N_samps = 200
N_iter = 10
m, samps = ift.library.generate_krylov_samples(D_inv, S, j, N_samps, N_iter)
N_iter = 100
IC = ift.GradientNormController(tol_abs_gradnorm=1e-3, iteration_limit=N_iter)
m, samps = ift.library.generate_krylov_samples(D_inv, S, j, N_samps, IC)
m_x = sky(m)
IC = ift.GradientNormController(iteration_limit=N_iter)
inverter = ift.ConjugateGradient(IC)
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter)
samps_old = []
for i in range(N_samps):
samps_old += [curv.draw_sample(from_inverse=True)]
samps_old = [curv.draw_sample(from_inverse=True) for i in range(N_samps)]
plt.plot(d.val, '+', label="data", alpha=.5)
plt.plot(s_x.val, label="original")
plt.plot(m_x.val, label="reconstruction")
plt.plot(d.to_global_data(), '+', label="data", alpha=.5)
plt.plot(s_x.to_global_data(), label="original")
plt.plot(m_x.to_global_data(), label="reconstruction")
plt.legend()
plt.savefig('Krylov_reconstruction.png')
plt.close()
......@@ -54,32 +53,31 @@ plt.close()
pltdict = {'alpha': .3, 'linewidth': .2}
for i in range(N_samps):
if i == 0:
plt.plot(sky(samps_old[i]).val, color='b',
plt.plot(sky(samps_old[i]).to_global_data(), color='b',
label='Traditional samples (residuals)',
**pltdict)
plt.plot(sky(samps[i]).val, color='r',
plt.plot(sky(samps[i]).to_global_data(), color='r',
label='Krylov samples (residuals)',
**pltdict)
else:
plt.plot(sky(samps_old[i]).val, color='b', **pltdict)
plt.plot(sky(samps[i]).val, color='r', **pltdict)
plt.plot((s_x - m_x).val, color='k', label='signal - mean')
plt.plot(sky(samps_old[i]).to_global_data(), color='b', **pltdict)
plt.plot(sky(samps[i]).to_global_data(), color='r', **pltdict)
plt.plot((s_x - m_x).to_global_data(), color='k', label='signal - mean')
plt.legend()
plt.savefig('Krylov_samples_residuals.png')
plt.close()
D_hat_old = ift.Field.zeros(x_space).val
D_hat_new = ift.Field.zeros(x_space).val
D_hat_old = ift.Field.zeros(x_space).to_global_data()
D_hat_new = ift.Field.zeros(x_space).to_global_data()
for i in range(N_samps):
D_hat_old += sky(samps_old[i]).val**2
D_hat_new += sky(samps[i]).val**2
D_hat_old += sky(samps_old[i]).to_global_data()**2
D_hat_new += sky(samps[i]).to_global_data()**2
plt.plot(np.sqrt(D_hat_old / N_samps), 'r--', label='Traditional uncertainty')
plt.plot(-np.sqrt(D_hat_old / N_samps), 'r--')
plt.fill_between(range(len(D_hat_new)), -np.sqrt(D_hat_new / N_samps), np.sqrt(
D_hat_new / N_samps), facecolor='0.5', alpha=0.5,
label='Krylov uncertainty')
plt.plot((s_x - m_x).val, color='k', label='signal - mean')
plt.plot((s_x - m_x).to_global_data(), color='k', label='signal - mean')
plt.legend()
plt.savefig('Krylov_uncertainty.png')
plt.close()
......@@ -16,12 +16,12 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from numpy import sqrt
from numpy.random import randn
import numpy as np
from ..field import Field
from ..minimization.quadratic_energy import QuadraticEnergy
def generate_krylov_samples(D_inv, S, j=None, N_samps=1, N_iter=10,
name=None):
def generate_krylov_samples(D_inv, S, j, N_samps, controller):
"""
Generates inverse samples from a curvature D.
This algorithm iteratively generates samples from
......@@ -41,10 +41,10 @@ def generate_krylov_samples(D_inv, S, j=None, N_samps=1, N_iter=10,
of this matrix inversion problem is a side product of generating
the samples.
If not supplied, it is sampled from the inverse prior.
N_samps : Int, optional
How many samples to generate. Default: 1
N_iter : Int, optional
How many iterations of the conjugate gradient to run. Default: 10
N_samps : Int
How many samples to generate.
controller : IterationController
convergence controller for the conjugate gradient iteration
Returns
-------
......@@ -54,19 +54,30 @@ def generate_krylov_samples(D_inv, S, j=None, N_samps=1, N_iter=10,
and the second entry are a list of samples from D_inv.inverse
"""
j = S.draw_sample(from_inverse=True) if j is None else j
x = j*0
x = Field.full(D_inv.domain, 0.)
energy = QuadraticEnergy(x, D_inv, j)
y = [S.draw_sample() for _ in range(N_samps)]
status = controller.start(energy)
if status != controller.CONTINUE:
return x, y
r = j.copy()
p = r.copy()
d = p.vdot(D_inv(p))
y = [S.draw_sample() for _ in range(N_samps)]
for k in range(1, 1+N_iter):
while True:
gamma = r.vdot(r)/d
if gamma == 0.:
break
x += gamma*p
for i in range(N_samps):
y[i] -= p.vdot(D_inv(y[i])) * p / d
y[i] += randn() / sqrt(d) * p
x = x + gamma*p
for samp in y:
samp -= p.vdot(D_inv(samp)) * p / d
samp += np.random.randn() / np.sqrt(d) * p
energy = energy.at(x)
status = controller.check(energy)
if status != controller.CONTINUE:
return x, y
r_new = r - gamma * D_inv(p)
beta = r_new.vdot(r_new) / r.vdot(r)
r = r_new
......@@ -74,6 +85,4 @@ def generate_krylov_samples(D_inv, S, j=None, N_samps=1, N_iter=10,
d = p.vdot(D_inv(p))
if d == 0.:
break
if name is not None:
print('{}: Iteration #{}'.format(name, k))
return x, y
......@@ -74,7 +74,7 @@ class InversionEnabler(EndomorphicOperator):
prec = self._approximation
if prec is not None:
prec = prec._flip_modes(self._ilog[mode])
energy = QuadraticEnergy(A=invop, b=x, position=x0)
energy = QuadraticEnergy(x0, invop, x)
r, stat = self._inverter(energy, preconditioner=prec)
if stat != IterationController.CONVERGED:
logger.warning("Error detected during operator inversion")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment