From 329b16b85fa877f93f51959d0107dbd6d6edfaa1 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Mon, 23 Apr 2018 22:18:19 +0200
Subject: [PATCH] fix Krylov code

---
 demos/krylov_sampling.py              | 46 +++++++++++++--------------
 nifty4/library/krylov_sampling.py     | 41 ++++++++++++++----------
 nifty4/operators/inversion_enabler.py |  2 +-
 3 files changed, 48 insertions(+), 41 deletions(-)

diff --git a/demos/krylov_sampling.py b/demos/krylov_sampling.py
index 8976cc2ba..7b316e4db 100644
--- a/demos/krylov_sampling.py
+++ b/demos/krylov_sampling.py
@@ -9,11 +9,12 @@ x_space = ift.RGSpace(1024)
 h_space = x_space.get_default_codomain()
 
 d_space = x_space
-N_hat = ift.Field(d_space, 10.)
-N_hat.val[400:450] = 0.0001
-N = ift.DiagonalOperator(N_hat, d_space)
+N_hat = np.full(d_space.shape, 10.)
+N_hat[400:450] = 0.0001
+N_hat = ift.Field.from_global_data(d_space, N_hat)
+N = ift.DiagonalOperator(N_hat)
 
-FFT = ift.HarmonicTransformOperator(h_space, target=x_space)
+FFT = ift.HarmonicTransformOperator(h_space, x_space)
 R = ift.ScalingOperator(1., x_space)
 
 
@@ -34,19 +35,17 @@ D_inv = ift.SandwichOperator(R_p, N.inverse) + S.inverse
 
 
 N_samps = 200
-N_iter = 10
-m, samps = ift.library.generate_krylov_samples(D_inv, S, j, N_samps, N_iter)
+N_iter = 100
+IC = ift.GradientNormController(tol_abs_gradnorm=1e-3, iteration_limit=N_iter)
+m, samps = ift.library.generate_krylov_samples(D_inv, S, j, N_samps, IC)
 m_x = sky(m)
-IC = ift.GradientNormController(iteration_limit=N_iter)
 inverter = ift.ConjugateGradient(IC)
 curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter)
-samps_old = []
-for i in range(N_samps):
-    samps_old += [curv.draw_sample(from_inverse=True)]
+samps_old = [curv.draw_sample(from_inverse=True) for i in range(N_samps)]
 
-plt.plot(d.val, '+', label="data", alpha=.5)
-plt.plot(s_x.val, label="original")
-plt.plot(m_x.val, label="reconstruction")
+plt.plot(d.to_global_data(), '+', label="data", alpha=.5)
+plt.plot(s_x.to_global_data(), label="original")
+plt.plot(m_x.to_global_data(), label="reconstruction")
 plt.legend()
 plt.savefig('Krylov_reconstruction.png')
 plt.close()
@@ -54,32 +53,31 @@ plt.close()
 pltdict = {'alpha': .3, 'linewidth': .2}
 for i in range(N_samps):
     if i == 0:
-        plt.plot(sky(samps_old[i]).val, color='b',
+        plt.plot(sky(samps_old[i]).to_global_data(), color='b',
                  label='Traditional samples (residuals)',
                  **pltdict)
-        plt.plot(sky(samps[i]).val, color='r',
+        plt.plot(sky(samps[i]).to_global_data(), color='r',
                  label='Krylov samples (residuals)',
                  **pltdict)
     else:
-        plt.plot(sky(samps_old[i]).val, color='b', **pltdict)
-        plt.plot(sky(samps[i]).val, color='r', **pltdict)
-plt.plot((s_x - m_x).val, color='k', label='signal - mean')
+        plt.plot(sky(samps_old[i]).to_global_data(), color='b', **pltdict)
+        plt.plot(sky(samps[i]).to_global_data(), color='r', **pltdict)
+plt.plot((s_x - m_x).to_global_data(), color='k', label='signal - mean')
 plt.legend()
 plt.savefig('Krylov_samples_residuals.png')
 plt.close()
 
-D_hat_old = ift.Field.zeros(x_space).val
-D_hat_new = ift.Field.zeros(x_space).val
+D_hat_old = ift.Field.zeros(x_space).to_global_data()
+D_hat_new = ift.Field.zeros(x_space).to_global_data()
 for i in range(N_samps):
-    D_hat_old += sky(samps_old[i]).val**2
-    D_hat_new += sky(samps[i]).val**2
+    D_hat_old += sky(samps_old[i]).to_global_data()**2
+    D_hat_new += sky(samps[i]).to_global_data()**2
 plt.plot(np.sqrt(D_hat_old / N_samps), 'r--', label='Traditional uncertainty')
 plt.plot(-np.sqrt(D_hat_old / N_samps), 'r--')
 plt.fill_between(range(len(D_hat_new)), -np.sqrt(D_hat_new / N_samps), np.sqrt(
     D_hat_new / N_samps), facecolor='0.5', alpha=0.5,
     label='Krylov uncertainty')
-plt.plot((s_x - m_x).val, color='k', label='signal - mean')
+plt.plot((s_x - m_x).to_global_data(), color='k', label='signal - mean')
 plt.legend()
 plt.savefig('Krylov_uncertainty.png')
 plt.close()
-
diff --git a/nifty4/library/krylov_sampling.py b/nifty4/library/krylov_sampling.py
index e592a0e92..c264af375 100644
--- a/nifty4/library/krylov_sampling.py
+++ b/nifty4/library/krylov_sampling.py
@@ -16,12 +16,12 @@
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
 # and financially supported by the Studienstiftung des deutschen Volkes.
 
-from numpy import sqrt
-from numpy.random import randn
+import numpy as np
+from ..field import Field
+from ..minimization.quadratic_energy import QuadraticEnergy
 
 
-def generate_krylov_samples(D_inv, S, j=None,  N_samps=1, N_iter=10,
-                            name=None):
+def generate_krylov_samples(D_inv, S, j, N_samps, controller):
     """
     Generates inverse samples from a curvature D.
     This algorithm iteratively generates samples from
@@ -41,10 +41,10 @@ def generate_krylov_samples(D_inv, S, j=None,  N_samps=1, N_iter=10,
         of this matrix inversion problem is a side product of generating
         the samples.
         If not supplied, it is sampled from the inverse prior.
-    N_samps : Int, optional
-        How many samples to generate. Default: 1
-    N_iter : Int, optional
-        How many iterations of the conjugate gradient to run. Default: 10
+    N_samps : Int
+        How many samples to generate.
+    controller : IterationController
+        convergence controller for the conjugate gradient iteration
 
     Returns
     -------
@@ -54,19 +54,30 @@ def generate_krylov_samples(D_inv, S, j=None,  N_samps=1, N_iter=10,
         and the second entry are a list of samples from D_inv.inverse
     """
     j = S.draw_sample(from_inverse=True) if j is None else j
-    x = j*0
+    x = Field.full(D_inv.domain, 0.)
+    energy = QuadraticEnergy(x, D_inv, j)
+    y = [S.draw_sample() for _ in range(N_samps)]
+
+    status = controller.start(energy)
+    if status != controller.CONTINUE:
+        return x, y
+
     r = j.copy()
     p = r.copy()
     d = p.vdot(D_inv(p))
     y = [S.draw_sample() for _ in range(N_samps)]
-    for k in range(1, 1+N_iter):
+    while True:
         gamma = r.vdot(r)/d
         if gamma == 0.:
             break
-        x += gamma*p
-        for i in range(N_samps):
-            y[i] -= p.vdot(D_inv(y[i])) * p / d
-            y[i] += randn() / sqrt(d) * p
+        x = x + gamma*p
+        for samp in y:
+            samp -= p.vdot(D_inv(samp)) * p / d
+            samp += np.random.randn() / np.sqrt(d) * p
+        energy = energy.at(x)
+        status = controller.check(energy)
+        if status != controller.CONTINUE:
+            return x, y
         r_new = r - gamma * D_inv(p)
         beta = r_new.vdot(r_new) / r.vdot(r)
         r = r_new
@@ -74,6 +85,4 @@ def generate_krylov_samples(D_inv, S, j=None,  N_samps=1, N_iter=10,
         d = p.vdot(D_inv(p))
         if d == 0.:
             break
-        if name is not None:
-            print('{}: Iteration #{}'.format(name, k))
     return x, y
diff --git a/nifty4/operators/inversion_enabler.py b/nifty4/operators/inversion_enabler.py
index dbdd59314..91c5a380f 100644
--- a/nifty4/operators/inversion_enabler.py
+++ b/nifty4/operators/inversion_enabler.py
@@ -74,7 +74,7 @@ class InversionEnabler(EndomorphicOperator):
         prec = self._approximation
         if prec is not None:
             prec = prec._flip_modes(self._ilog[mode])
-        energy = QuadraticEnergy(A=invop, b=x, position=x0)
+        energy = QuadraticEnergy(x0, invop, x)
         r, stat = self._inverter(energy, preconditioner=prec)
         if stat != IterationController.CONVERGED:
             logger.warning("Error detected during operator inversion")
-- 
GitLab