From 4a35010ca362c855efb992b334bd129d25d31bb7 Mon Sep 17 00:00:00 2001
From: Reimar Leike <reimar@leike.name>
Date: Mon, 4 Jun 2018 15:32:45 +0200
Subject: [PATCH] introduced a new sampling routine that does not have the
 systematic errors of the old krylov samples but is slower

---
 demos/krylov_sampling.py          |   5 +-
 nifty4/library/krylov_sampling.py | 121 +++++++++++++++---------------
 2 files changed, 65 insertions(+), 61 deletions(-)

diff --git a/demos/krylov_sampling.py b/demos/krylov_sampling.py
index 4987437ae..afb0822f2 100644
--- a/demos/krylov_sampling.py
+++ b/demos/krylov_sampling.py
@@ -37,10 +37,11 @@ D_inv = ift.SandwichOperator.make(R_p, N.inverse) + S.inverse
 N_samps = 200
 N_iter = 100
 IC = ift.GradientNormController(tol_abs_gradnorm=1e-3, iteration_limit=N_iter)
-m, samps = ift.library.generate_krylov_samples(D_inv, S, j, N_samps, IC)
-m_x = sky(m)
+samps = ift.library.generate_krylov_samples(D_inv, S, N_samps, IC)
 inverter = ift.ConjugateGradient(IC)
 curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter)
+m = curv.inverse_times(j)
+m_x = sky(m)
 samps_old = [curv.draw_sample(from_inverse=True) for i in range(N_samps)]
 
 plt.plot(d.to_global_data(), '+', label="data", alpha=.5)
diff --git a/nifty4/library/krylov_sampling.py b/nifty4/library/krylov_sampling.py
index d83a8a8f2..3196d0c52 100644
--- a/nifty4/library/krylov_sampling.py
+++ b/nifty4/library/krylov_sampling.py
@@ -20,26 +20,23 @@ import numpy as np
 from ..minimization.quadratic_energy import QuadraticEnergy
 
 
-def generate_krylov_samples(D_inv, S, j, N_samps, controller):
+def generate_krylov_samples(D_inv, S, N_samps, controller):
     """
     Generates inverse samples from a curvature D.
     This algorithm iteratively generates samples from
     a curvature D by applying conjugate gradient steps
     and resampling the curvature in search direction.
+    It is basically just a more stable version of
+    Wiener Filter samples
 
     Parameters
     ----------
-    D_inv : EndomorphicOperator
-        The curvature which will be the inverse of the covarianc
+    D_inv : WienerFilterCurvature
+        The curvature which will be the inverse of the covariance
         of the generated samples
     S : EndomorphicOperator (from which one can sample)
         A prior covariance operator which is used to generate prior
         samples that are then iteratively updated
-    j : Field, optional
-        A Field to which the inverse of D_inv is applied. The solution
-        of this matrix inversion problem is a side product of generating
-        the samples.
-        If not supplied, it is sampled from the inverse prior.
     N_samps : Int
         How many samples to generate.
     controller : IterationController
@@ -47,56 +44,62 @@ def generate_krylov_samples(D_inv, S, j, N_samps, controller):
 
     Returns
     -------
-    (solution, samples) : A tuple of a field 'solution' and a list of fields
-        'samples'. The first entry of the tuple is the solution x to
-            D_inv(x) = j
-        and the second entry are a list of samples from D_inv.inverse
+    samples : a list of samples from D_inv.inverse
     """
-    # RL FIXME: make consistent with complex numbers
-    j = S.draw_sample(from_inverse=True) if j is None else j
-    energy = QuadraticEnergy(j.empty_copy().fill(0.), D_inv, j)
-    y = [S.draw_sample() for _ in range(N_samps)]
-
-    status = controller.start(energy)
-    if status != controller.CONTINUE:
-        return energy.position, y
-
-    r = energy.gradient
-    d = r.copy()
-
-    previous_gamma = r.vdot(r).real
-    if previous_gamma == 0:
-        return energy.position, y
-
-    while True:
-        q = energy.curvature(d)
-        ddotq = d.vdot(q).real
-        if ddotq == 0.:
-            logger.error("Error: ConjugateGradient: ddotq==0.")
-            return energy.position, y
-        alpha = previous_gamma/ddotq
-
-        if alpha < 0:
-            logger.error("Error: ConjugateGradient: alpha<0.")
-            return energy.position, y
-
-        for i in range(len(y)):
-            y[i] += (np.random.randn()*np.sqrt(ddotq) - y[i].vdot(q))/ddotq * d
-
-        q *= -alpha
-        r = r + q
-
-        energy = energy.at_with_grad(energy.position - alpha*d, r)
-
-        gamma = r.vdot(r).real
-        if gamma == 0:
-            return energy.position, y
-
-        status = controller.check(energy)
+    samples = []
+    for i in range(N_samps):
+        x0 = S.draw_sample()
+        y = x0*0
+        j = y*0
+        #j = y
+        energy = QuadraticEnergy(x0, D_inv, j)
+
+        status = controller.start(energy)
         if status != controller.CONTINUE:
-            return energy.position, y
-
-        d *= max(0, gamma/previous_gamma)
-        d += r
-
-        previous_gamma = gamma
+            samples += [y]
+            break
+
+        r = energy.gradient
+        d = r.copy()
+
+        previous_gamma = r.vdot(r).real
+        if previous_gamma == 0:
+            samples += [y+energy.position]
+            break
+
+        while True:
+            q = energy.curvature(d)
+            ddotq = d.vdot(q).real
+            if ddotq == 0.:
+                logger.error("Error: ConjugateGradient: ddotq==0.")
+                samples += [y+energy.position]
+                break
+            alpha = previous_gamma/ddotq
+
+            if alpha < 0:
+                logger.error("Error: ConjugateGradient: alpha<0.")
+                samples += [y+energy.position]
+                break
+    
+            y += (np.random.randn()*np.sqrt(ddotq) )/ddotq * d
+
+            q *= -alpha
+            r = r + q
+
+            energy = energy.at_with_grad(energy.position - alpha*d, r)
+
+            gamma = r.vdot(r).real
+            if gamma == 0:
+                samples += [y+energy.position]
+                break
+
+            status = controller.check(energy)
+            if status != controller.CONTINUE:
+                samples += [y+energy.position]
+                break
+
+            d *= max(0, gamma/previous_gamma)
+            d += r
+
+            previous_gamma = gamma
+    return samples
-- 
GitLab