From e9a5b0f18b33c1e16238721eebd815b54f98efec Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Sun, 1 Apr 2018 13:55:42 +0200
Subject: [PATCH] make inverse_draw_sample() largely obsolete

---
 nifty4/operators/diagonal_operator.py    | 17 +++++------------
 nifty4/operators/endomorphic_operator.py | 15 +++++++++------
 nifty4/operators/inversion_enabler.py    | 17 ++++++-----------
 nifty4/operators/operator_adapter.py     | 11 +++--------
 nifty4/operators/sandwich_operator.py    |  7 +++++--
 nifty4/operators/scaling_operator.py     | 10 +++-------
 nifty4/operators/sum_operator.py         |  8 +++++---
 7 files changed, 36 insertions(+), 49 deletions(-)

diff --git a/nifty4/operators/diagonal_operator.py b/nifty4/operators/diagonal_operator.py
index c2d0b5dd7..9c3652226 100644
--- a/nifty4/operators/diagonal_operator.py
+++ b/nifty4/operators/diagonal_operator.py
@@ -174,21 +174,14 @@ class DiagonalOperator(EndomorphicOperator):
             raise ValueError("bad operator flipping mode")
         return res
 
-    def draw_sample(self, dtype=np.float64):
+    def draw_sample(self, from_inverse=False, dtype=np.float64):
         if (np.issubdtype(self._ldiag.dtype, np.complexfloating) or
                 (self._ldiag <= 0.).any()):
             raise ValueError("operator not positive definite")
         res = Field.from_random(random_type="normal", domain=self._domain,
                                 dtype=dtype)
-        res.local_data[()] *= np.sqrt(self._ldiag)
-        return res
-
-    def inverse_draw_sample(self, dtype=np.float64):
-        if (np.issubdtype(self._ldiag.dtype, np.complexfloating) or
-                (self._ldiag <= 0.).any()):
-            raise ValueError("operator not positive definite")
-
-        res = Field.from_random(random_type="normal", domain=self._domain,
-                                dtype=dtype)
-        res.local_data[()] /= np.sqrt(self._ldiag)
+        if from_inverse:
+            res.local_data[()] /= np.sqrt(self._ldiag)
+        else:
+            res.local_data[()] *= np.sqrt(self._ldiag)
         return res
diff --git a/nifty4/operators/endomorphic_operator.py b/nifty4/operators/endomorphic_operator.py
index e456923c0..0046d8986 100644
--- a/nifty4/operators/endomorphic_operator.py
+++ b/nifty4/operators/endomorphic_operator.py
@@ -36,12 +36,19 @@ class EndomorphicOperator(LinearOperator):
         for endomorphic operators."""
         return self.domain
 
-    def draw_sample(self, dtype=np.float64):
+    def draw_sample(self, from_inverse=False, dtype=np.float64):
         """Generate a zero-mean sample
 
         Generates a sample from a Gaussian distribution with zero mean and
         covariance given by the operator.
 
+        Parameters
+        ----------
+        from_inverse : bool (default : False)
+            if True, the sample is drawn from the inverse of the operator
+        dtype : numpy datatype (default : numpy.float64)
+            the data type to be used for the sample
+
         Returns
         -------
         Field
@@ -59,8 +66,4 @@ class EndomorphicOperator(LinearOperator):
         -------
             A sample from the Gaussian of given covariance
         """
-        if self.capability & self.INVERSE_TIMES:
-            x = self.draw_sample(dtype)
-            return self.inverse_times(x)
-        else:
-            raise NotImplementedError
+        return self.draw_sample(True, dtype)
diff --git a/nifty4/operators/inversion_enabler.py b/nifty4/operators/inversion_enabler.py
index 774318cb8..dbdd59314 100644
--- a/nifty4/operators/inversion_enabler.py
+++ b/nifty4/operators/inversion_enabler.py
@@ -20,11 +20,11 @@ from ..minimization.quadratic_energy import QuadraticEnergy
 from ..minimization.iteration_controller import IterationController
 from ..field import Field
 from ..logger import logger
-from .linear_operator import LinearOperator
+from .endomorphic_operator import EndomorphicOperator
 import numpy as np
 
 
-class InversionEnabler(LinearOperator):
+class InversionEnabler(EndomorphicOperator):
     """Class which augments the capability of another operator object via
     numerical inversion.
 
@@ -80,14 +80,9 @@ class InversionEnabler(LinearOperator):
             logger.warning("Error detected during operator inversion")
         return r.position
 
-    def draw_sample(self, dtype=np.float64):
+    def draw_sample(self, from_inverse=False, dtype=np.float64):
         try:
-            return self._op.draw_sample(dtype)
+            return self._op.draw_sample(from_inverse, dtype)
         except:
-            return self(self._op.inverse_draw_sample(dtype))
-
-    def inverse_draw_sample(self, dtype=np.float64):
-        try:
-            return self._op.inverse_draw_sample(dtype)
-        except:
-            return self.inverse_times(self._op.draw_sample(dtype))
+            samp = self._op.draw_sample(not from_inverse, dtype)
+            return self.inverse_times(samp) if from_inverse else self(samp)
diff --git a/nifty4/operators/operator_adapter.py b/nifty4/operators/operator_adapter.py
index 61ce31865..036045e79 100644
--- a/nifty4/operators/operator_adapter.py
+++ b/nifty4/operators/operator_adapter.py
@@ -49,12 +49,7 @@ class OperatorAdapter(LinearOperator):
     def apply(self, x, mode):
         return self._op.apply(x, self._modeTable[self._mode][self._ilog[mode]])
 
-    def draw_sample(self, dtype=np.float64):
+    def draw_sample(self, from_inverse=False, dtype=np.float64):
         if self._mode & self.INVERSE_BIT:
-            return self._op.inverse_draw_sample(dtype)
-        return self._op.draw_sample(dtype)
-
-    def inverse_draw_sample(self, dtype=np.float64):
-        if self._mode & self.INVERSE_BIT:
-            return self._op.draw_sample(dtype)
-        return self._op.inverse_draw_sample(dtype)
+            return self._op.draw_sample(not from_inverse, dtype)
+        return self._op.draw_sample(from_inverse, dtype)
diff --git a/nifty4/operators/sandwich_operator.py b/nifty4/operators/sandwich_operator.py
index e324b135d..464072adb 100644
--- a/nifty4/operators/sandwich_operator.py
+++ b/nifty4/operators/sandwich_operator.py
@@ -48,5 +48,8 @@ class SandwichOperator(EndomorphicOperator):
     def apply(self, x, mode):
         return self._op.apply(x, mode)
 
-    def draw_sample(self, dtype=np.float64):
-        return self._bun.adjoint_times(self._cheese.draw_sample(dtype))
+    def draw_sample(self, from_inverse=False, dtype=np.float64):
+        if from_inverse:
+            raise ValueError("cannot draw from inverse of this operator")
+        return self._bun.adjoint_times(
+            self._cheese.draw_sample(from_inverse, dtype))
diff --git a/nifty4/operators/scaling_operator.py b/nifty4/operators/scaling_operator.py
index ddb8eaaae..cce8fee4b 100644
--- a/nifty4/operators/scaling_operator.py
+++ b/nifty4/operators/scaling_operator.py
@@ -93,14 +93,10 @@ class ScalingOperator(EndomorphicOperator):
     def capability(self):
         return self._all_ops
 
-    def _sample_helper(self, fct, dtype):
+    def draw_sample(self, from_inverse=False, dtype=np.float64):
+        fct = self._factor
         if fct.imag != 0. or fct.real <= 0.:
             raise ValueError("operator not positive definite")
+        fct = 1./np.sqrt(fct) if from_inverse else np.sqrt(fct)
         return Field.from_random(
            random_type="normal", domain=self._domain, std=fct, dtype=dtype)
-
-    def draw_sample(self, dtype=np.float64):
-        return self._sample_helper(np.sqrt(self._factor), dtype)
-
-    def inverse_draw_sample(self, dtype=np.float64):
-        return self._sample_helper(1./np.sqrt(self._factor), dtype)
diff --git a/nifty4/operators/sum_operator.py b/nifty4/operators/sum_operator.py
index be36836af..e2d5d70f3 100644
--- a/nifty4/operators/sum_operator.py
+++ b/nifty4/operators/sum_operator.py
@@ -143,8 +143,10 @@ class SumOperator(LinearOperator):
                     res += op.apply(x, mode)
         return res
 
-    def draw_sample(self, dtype=np.float64):
-        res = self._ops[0].draw_sample(dtype)
+    def draw_sample(self, from_inverse=False, dtype=np.float64):
+        if from_inverse:
+            raise ValueError("cannot draw from inverse of this operator")
+        res = self._ops[0].draw_sample(from_inverse, dtype)
         for op in self._ops[1:]:
-            res += op.draw_sample(dtype)
+            res += op.draw_sample(from_inverse, dtype)
         return res
-- 
GitLab