Commit c305952a authored by Reimar Heinrich Leike

resolved merge conflict and made krylov sampling faster by caching matrix...

resolved merge conflict and made krylov sampling faster by caching matrix multiplications, thus only one matrix multiplication per iteration is needed
parents 478a37dd 72c1a501
Pipeline #27916 passed with stages
in 11 minutes and 15 seconds
......@@ -9,11 +9,12 @@ x_space = ift.RGSpace(1024)
h_space = x_space.get_default_codomain()
d_space = x_space
N_hat = ift.Field(d_space, 10.)
N_hat.val[400:450] = 0.0001
N = ift.DiagonalOperator(N_hat, d_space)
N_hat = np.full(d_space.shape, 10.)
N_hat[400:450] = 0.0001
N_hat = ift.Field.from_global_data(d_space, N_hat)
N = ift.DiagonalOperator(N_hat)
FFT = ift.HarmonicTransformOperator(h_space, target=x_space)
FFT = ift.HarmonicTransformOperator(h_space, x_space)
R = ift.ScalingOperator(1., x_space)
......@@ -34,19 +35,17 @@ D_inv = ift.SandwichOperator(R_p, N.inverse) + S.inverse
N_samps = 200
N_iter = 10
m, samps = ift.library.generate_krylov_samples(D_inv, S, j, N_samps, N_iter)
N_iter = 100
IC = ift.GradientNormController(tol_abs_gradnorm=1e-3, iteration_limit=N_iter)
m, samps = ift.library.generate_krylov_samples(D_inv, S, j, N_samps, IC)
m_x = sky(m)
IC = ift.GradientNormController(iteration_limit=N_iter)
inverter = ift.ConjugateGradient(IC)
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter)
samps_old = []
for i in range(N_samps):
samps_old += [curv.draw_sample(from_inverse=True)]
samps_old = [curv.draw_sample(from_inverse=True) for i in range(N_samps)]
plt.plot(d.val, '+', label="data", alpha=.5)
plt.plot(s_x.val, label="original")
plt.plot(m_x.val, label="reconstruction")
plt.plot(d.to_global_data(), '+', label="data", alpha=.5)
plt.plot(s_x.to_global_data(), label="original")
plt.plot(m_x.to_global_data(), label="reconstruction")
plt.legend()
plt.savefig('Krylov_reconstruction.png')
plt.close()
......@@ -54,32 +53,31 @@ plt.close()
pltdict = {'alpha': .3, 'linewidth': .2}
for i in range(N_samps):
if i == 0:
plt.plot(sky(samps_old[i]).val, color='b',
plt.plot(sky(samps_old[i]).to_global_data(), color='b',
label='Traditional samples (residuals)',
**pltdict)
plt.plot(sky(samps[i]).val, color='r',
plt.plot(sky(samps[i]).to_global_data(), color='r',
label='Krylov samples (residuals)',
**pltdict)
else:
plt.plot(sky(samps_old[i]).val, color='b', **pltdict)
plt.plot(sky(samps[i]).val, color='r', **pltdict)
plt.plot((s_x - m_x).val, color='k', label='signal - mean')
plt.plot(sky(samps_old[i]).to_global_data(), color='b', **pltdict)
plt.plot(sky(samps[i]).to_global_data(), color='r', **pltdict)
plt.plot((s_x - m_x).to_global_data(), color='k', label='signal - mean')
plt.legend()
plt.savefig('Krylov_samples_residuals.png')
plt.close()
D_hat_old = ift.Field.zeros(x_space).val
D_hat_new = ift.Field.zeros(x_space).val
D_hat_old = ift.Field.zeros(x_space).to_global_data()
D_hat_new = ift.Field.zeros(x_space).to_global_data()
for i in range(N_samps):
D_hat_old += sky(samps_old[i]).val**2
D_hat_new += sky(samps[i]).val**2
D_hat_old += sky(samps_old[i]).to_global_data()**2
D_hat_new += sky(samps[i]).to_global_data()**2
plt.plot(np.sqrt(D_hat_old / N_samps), 'r--', label='Traditional uncertainty')
plt.plot(-np.sqrt(D_hat_old / N_samps), 'r--')
plt.fill_between(range(len(D_hat_new)), -np.sqrt(D_hat_new / N_samps), np.sqrt(
D_hat_new / N_samps), facecolor='0.5', alpha=0.5,
label='Krylov uncertainty')
plt.plot((s_x - m_x).val, color='k', label='signal - mean')
plt.plot((s_x - m_x).to_global_data(), color='k', label='signal - mean')
plt.legend()
plt.savefig('Krylov_uncertainty.png')
plt.close()
......@@ -16,12 +16,11 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from numpy import sqrt
from numpy.random import randn
import numpy as np
from ..minimization.quadratic_energy import QuadraticEnergy
def generate_krylov_samples(D_inv, S, j=None, N_samps=1, N_iter=10,
name=None):
def generate_krylov_samples(D_inv, S, j, N_samps, controller):
"""
Generates inverse samples from a curvature D.
This algorithm iteratively generates samples from
......@@ -32,7 +31,7 @@ def generate_krylov_samples(D_inv, S, j=None, N_samps=1, N_iter=10,
----------
D_inv : EndomorphicOperator
The curvature which will be the inverse of the covariance
of the generated samples
of the generated samples
S : EndomorphicOperator (from which one can sample)
A prior covariance operator which is used to generate prior
samples that are then iteratively updated
......@@ -41,10 +40,10 @@ def generate_krylov_samples(D_inv, S, j=None, N_samps=1, N_iter=10,
of this matrix inversion problem is a side product of generating
the samples.
If not supplied, it is sampled from the inverse prior.
N_samps : Int, optional
How many samples to generate. Default: 1
N_iter : Int, optional
How many iterations of the conjugate gradient to run. Default: 10
N_samps : Int
How many samples to generate.
controller : IterationController
convergence controller for the conjugate gradient iteration
Returns
-------
......@@ -53,26 +52,36 @@ def generate_krylov_samples(D_inv, S, j=None, N_samps=1, N_iter=10,
D_inv(x) = j
and the second entry are a list of samples from D_inv.inverse
"""
# MR FIXME: this should be synchronized with the "official" Nifty CG
j = S.draw_sample(from_inverse=True) if j is None else j
x = j*0
x = j*0.
energy = QuadraticEnergy(x, D_inv, j)
y = [S.draw_sample() for _ in range(N_samps)]
status = controller.start(energy)
if status != controller.CONTINUE:
return x, y
r = j.copy()
p = r.copy()
d = p.vdot(D_inv(p))
y = [S.draw_sample() for _ in range(N_samps)]
for k in range(1, 1+N_iter):
while True:
gamma = r.vdot(r)/d
if gamma == 0.:
break
x += gamma*p
for i in range(N_samps):
y[i] += (randn() * sqrt(d) - p.vdot(D_inv(y[i]))) / d * p
r_new = r - gamma * D_inv(p)
x = x + gamma*p
Dip = D_inv(p)
for samp in y:
samp += (randn() * sqrt(d) - samp.vdot(Dip)) / d * p
energy = energy.at(x)
status = controller.check(energy)
if status != controller.CONTINUE:
return x, y
r_new = r - gamma * Dip
beta = r_new.vdot(r_new) / r.vdot(r)
r = r_new
p = r + beta * p
d = p.vdot(D_inv(p))
d = p.vdot(Dip)
if d == 0.:
break
if name is not None:
print('{}: Iteration #{}'.format(name, k))
return x, y
......@@ -63,6 +63,7 @@ class FFTOperator(LinearOperator):
import pyfftw
pyfftw.interfaces.cache.enable()
pyfftw.interfaces.cache.set_keepalive_time(1000.)
def apply(self, x, mode):
self._check_input(x, mode)
......
......@@ -74,7 +74,7 @@ class InversionEnabler(EndomorphicOperator):
prec = self._approximation
if prec is not None:
prec = prec._flip_modes(self._ilog[mode])
energy = QuadraticEnergy(A=invop, b=x, position=x0)
energy = QuadraticEnergy(x0, invop, x)
r, stat = self._inverter(energy, preconditioner=prec)
if stat != IterationController.CONVERGED:
logger.warning("Error detected during operator inversion")
......
......@@ -160,6 +160,14 @@ def NiftyMetaBase():
return with_metaclass(NiftyMeta, type('NewBase', (object,), {}))
def nthreads():
    """Return the number of threads to request for FFT calls.

    The value is taken from the ``OMP_NUM_THREADS`` environment variable the
    first time the function is called and is then cached on the function
    object (``nthreads._val``); later changes to the environment are ignored
    unless ``nthreads._val`` is reset to ``None``.

    Returns
    -------
    int
        The thread count, always at least 1. If the variable is unset,
        empty, not an integer, or not positive, 1 is returned.
    """
    if nthreads._val is None:
        import os
        raw = os.getenv("OMP_NUM_THREADS", "1")
        # OpenMP allows a comma-separated list (one entry per nesting
        # level); only the outermost level is meaningful here.
        outermost = raw.split(",")[0].strip()
        try:
            count = int(outermost)
        except ValueError:
            # Malformed or empty setting: fall back to single-threaded
            # operation instead of failing inside an FFT call.
            count = 1
        # A thread count below 1 is never valid.
        nthreads._val = max(count, 1)
    return nthreads._val


# Cache slot for the lazily determined thread count.
nthreads._val = None
def hartley(a, axes=None):
# Check if the axes provided are valid given the shape
if axes is not None and \
......@@ -169,7 +177,7 @@ def hartley(a, axes=None):
raise TypeError("Hartley transform requires real-valued arrays.")
from pyfftw.interfaces.numpy_fft import rfftn
tmp = rfftn(a, axes=axes)
tmp = rfftn(a, axes=axes, threads=nthreads())
def _fill_array(tmp, res, axes):
if axes is None:
......@@ -211,7 +219,7 @@ def my_fftn_r2c(a, axes=None):
raise TypeError("Transform requires real-valued input arrays.")
from pyfftw.interfaces.numpy_fft import rfftn
tmp = rfftn(a, axes=axes)
tmp = rfftn(a, axes=axes, threads=nthreads())
def _fill_complex_array(tmp, res, axes):
if axes is None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment