From b4d6eceb79f3279d506e88666ca3d4bb63db47e9 Mon Sep 17 00:00:00 2001
From: Reimar Leike <reimar@leike.name>
Date: Wed, 18 Jul 2018 17:41:48 +0200
Subject: [PATCH] added more consisteny checks, fixed a bug where SlopeOperator
 would not respect dtype and NNifty is now exporting NullOperator

---
 nifty5/__init__.py                  |  1 +
 nifty5/operators/slope_operator.py  |  2 +-
 test/test_operators/test_adjoint.py | 72 ++++++++++++++++++++++++++++-
 3 files changed, 73 insertions(+), 2 deletions(-)

diff --git a/nifty5/__init__.py b/nifty5/__init__.py
index 698d8d499..256a6e096 100644
--- a/nifty5/__init__.py
+++ b/nifty5/__init__.py
@@ -42,6 +42,7 @@ from .operators.laplace_operator import LaplaceOperator
 from .operators.linear_operator import LinearOperator
 from .operators.mask_operator import MaskOperator
 from .operators.multi_adaptor import MultiAdaptor
+from .operators.null_operator import NullOperator
 from .operators.power_distributor import PowerDistributor
 from .operators.qht_operator import QHTOperator
 from .operators.sampling_enabler import SamplingEnabler
diff --git a/nifty5/operators/slope_operator.py b/nifty5/operators/slope_operator.py
index 1dd6d055a..dadee40c0 100644
--- a/nifty5/operators/slope_operator.py
+++ b/nifty5/operators/slope_operator.py
@@ -76,7 +76,7 @@ class SlopeOperator(LinearOperator):
             return Field.from_global_data(self.target, res)
 
         # Adjoint times
-        res = np.zeros(self.domain[0].shape)
+        res = np.zeros(self.domain[0].shape, dtype=x.dtype)
         xglob = x.to_global_data()
         res[-1] = np.sum(xglob) * self._sigmas[-1]
         for i in range(self.ndim):
diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py
index bdbe13de6..20b374962 100644
--- a/test/test_operators/test_adjoint.py
+++ b/test/test_operators/test_adjoint.py
@@ -35,7 +35,77 @@ _pow_spaces = [ift.PowerSpace(ift.RGSpace((17, 38), harmonic=True))]
 
 
 class Consistency_Tests(unittest.TestCase):
-    @expand(product(_h_spaces + _p_RG_spaces + _p_spaces
+    @expand(product([(ift.RGSpace(10, harmonic=True), 4, 0),
+                     (ift.RGSpace((24, 31), distances=(0.4, 2.34),
+                                  harmonic=True), 3, 0),
+                     (ift.LMSpace(4), 10, 0)],
+                    [np.float64, np.complex128]))
+    def testSlopeOperator(self, args, dtype):
+        tmp = ift.ExpTransform(ift.PowerSpace(args[0]), args[1], args[2])
+        tgt = tmp.domain[0]
+        sig = np.array([0.3, 0.13])
+        dom = ift.UnstructuredDomain(2)
+        op = ift.SlopeOperator(dom, tgt, sig)
+        ift.extra.consistency_check(op, dtype, dtype)
+
+    @expand(product(_h_spaces + _p_spaces + _pow_spaces,
+                    _h_spaces + _p_spaces + _pow_spaces,
+                    [np.float64, np.complex128]))
+    def testSelectionOperator(self, sp1, sp2, dtype):
+        mdom = ift.MultiDomain.make({'a':sp1, 'b':sp2})
+        op = ift.SelectionOperator(mdom, 'a')
+        ift.extra.consistency_check(op, dtype, dtype)
+
+    @expand(product(_h_spaces + _p_spaces + _pow_spaces,
+                    [np.float64, np.complex128]))
+    def testSandwichOperator(self, sp, dtype):
+        bun = ift.DiagonalOperator(ift.Field.from_random("normal", sp,
+                                                        dtype=dtype))
+        cheese = ift.DiagonalOperator(ift.Field.from_random("normal", sp,
+                                                        dtype=dtype))
+        op = ift.SandwichOperator.make(bun, cheese)
+        ift.extra.consistency_check(op, dtype, dtype)
+
+    @expand(product(_h_spaces + _p_spaces + _pow_spaces,
+                    [np.float64, np.complex128]))
+    def testOperatorAdaptor(self, sp, dtype):
+        op = ift.DiagonalOperator(ift.Field.from_random("normal", sp,
+                                                        dtype=dtype))
+        ift.extra.consistency_check(op.adjoint, dtype, dtype)
+        ift.extra.consistency_check(op.inverse, dtype, dtype)
+        ift.extra.consistency_check(op.inverse.adjoint, dtype, dtype)
+        ift.extra.consistency_check(op.adjoint.inverse, dtype, dtype)
+
+    @expand(product(_h_spaces + _p_spaces + _pow_spaces,
+                    _h_spaces + _p_spaces + _pow_spaces,
+                    [np.float64, np.complex128]))
+    def testNullOperator(self, sp1, sp2, dtype):
+        op = ift.NullOperator(sp1, sp2)
+        ift.extra.consistency_check(op, dtype, dtype)
+        mdom1 = ift.MultiDomain.make({'a':sp1})
+        mdom2 = ift.MultiDomain.make({'b':sp2})
+        op = ift.NullOperator(mdom1, mdom2)
+        ift.extra.consistency_check(op, dtype, dtype)
+        op = ift.NullOperator(sp1, mdom2)
+        ift.extra.consistency_check(op, dtype, dtype)
+        op = ift.NullOperator(mdom1, sp2)
+        ift.extra.consistency_check(op, dtype, dtype)
+        
+    @expand(product(_h_spaces + _p_spaces
+                        + _pow_spaces,
+                    [np.float64, np.complex128]))
+    def testMultiAdaptor(self, sp, dtype):
+        mdom = ift.MultiDomain.make({'a':sp})
+        op = ift.MultiAdaptor(mdom)
+        ift.extra.consistency_check(op, dtype, dtype)
+
+    @expand(product(_p_RG_spaces,
+                    [np.float64, np.complex128]))
+    def testHarmonicSmoothingOperator(self, sp, dtype):
+        op = ift.HarmonicSmoothingOperator(sp, 0.1)
+        ift.extra.consistency_check(op, dtype, dtype)
+
+    @expand(product(_h_spaces + _p_spaces
                         + _pow_spaces,
                     [np.float64, np.complex128]))
     def testDOFDistributor(self, sp, dtype):
-- 
GitLab