From c59be7601e218add56e359a01c824798d98bf2d6 Mon Sep 17 00:00:00 2001
From: Philipp Arras <parras@mpa-garching.mpg.de>
Date: Tue, 15 Oct 2019 17:53:44 +0200
Subject: [PATCH] Make Sampling enabler more general

---
 nifty5/operators/sampling_enabler.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/nifty5/operators/sampling_enabler.py b/nifty5/operators/sampling_enabler.py
index 2ceec16e2..883b8f3f6 100644
--- a/nifty5/operators/sampling_enabler.py
+++ b/nifty5/operators/sampling_enabler.py
@@ -42,16 +42,20 @@ class SamplingEnabler(EndomorphicOperator):
         operator, which supports the operation modes that the operator doesn't
         have. It is used as a preconditioner during the iterative inversion,
         to accelerate convergence.
+    start_from_zero : boolean
+        If true, the conjugate gradient algorithm starts from a field filled
+        with zeros. Otherwise, it starts from a prior samples. Default is
+        False.
     """
 
     def __init__(self, likelihood, prior, iteration_controller,
-                 approximation=None):
-        self._op = likelihood + prior
-        # FIXME Separation in likelihood and prior not necessary
+                 approximation=None, start_from_zero=False):
         self._likelihood = likelihood
         self._prior = prior
         self._ic = iteration_controller
         self._approximation = approximation
+        self._start_from_zero = bool(start_from_zero)
+        self._op = likelihood + prior
         self._domain = self._op.domain
         self._capability = self._op.capability
 
@@ -61,8 +65,15 @@ class SamplingEnabler(EndomorphicOperator):
         except NotImplementedError:
             if not from_inverse:
                 raise ValueError("from_inverse must be True here")
-            b = self._op.draw_sample()
-            energy = QuadraticEnergy(0*b, self._op, b)
+            if self._start_from_zero:
+                b = self._op.draw_sample()
+                energy = QuadraticEnergy(0*b, self._op, b)
+            else:
+                s = self._prior.draw_sample(from_inverse=True)
+                sp = self._prior(s)
+                nj = self._likelihood.draw_sample()
+                energy = QuadraticEnergy(s, self._op, sp + nj,
+                                         _grad=self._likelihood(s) - nj)
             inverter = ConjugateGradient(self._ic)
             if self._approximation is not None:
                 energy, convergence = inverter(
-- 
GitLab