From b4d6eceb79f3279d506e88666ca3d4bb63db47e9 Mon Sep 17 00:00:00 2001 From: Reimar Leike <reimar@leike.name> Date: Wed, 18 Jul 2018 17:41:48 +0200 Subject: [PATCH] added more consisteny checks, fixed a bug where SlopeOperator would not respect dtype and NNifty is now exporting NullOperator --- nifty5/__init__.py | 1 + nifty5/operators/slope_operator.py | 2 +- test/test_operators/test_adjoint.py | 72 ++++++++++++++++++++++++++++- 3 files changed, 73 insertions(+), 2 deletions(-) diff --git a/nifty5/__init__.py b/nifty5/__init__.py index 698d8d499..256a6e096 100644 --- a/nifty5/__init__.py +++ b/nifty5/__init__.py @@ -42,6 +42,7 @@ from .operators.laplace_operator import LaplaceOperator from .operators.linear_operator import LinearOperator from .operators.mask_operator import MaskOperator from .operators.multi_adaptor import MultiAdaptor +from .operators.null_operator import NullOperator from .operators.power_distributor import PowerDistributor from .operators.qht_operator import QHTOperator from .operators.sampling_enabler import SamplingEnabler diff --git a/nifty5/operators/slope_operator.py b/nifty5/operators/slope_operator.py index 1dd6d055a..dadee40c0 100644 --- a/nifty5/operators/slope_operator.py +++ b/nifty5/operators/slope_operator.py @@ -76,7 +76,7 @@ class SlopeOperator(LinearOperator): return Field.from_global_data(self.target, res) # Adjoint times - res = np.zeros(self.domain[0].shape) + res = np.zeros(self.domain[0].shape, dtype=x.dtype) xglob = x.to_global_data() res[-1] = np.sum(xglob) * self._sigmas[-1] for i in range(self.ndim): diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py index bdbe13de6..20b374962 100644 --- a/test/test_operators/test_adjoint.py +++ b/test/test_operators/test_adjoint.py @@ -35,7 +35,77 @@ _pow_spaces = [ift.PowerSpace(ift.RGSpace((17, 38), harmonic=True))] class Consistency_Tests(unittest.TestCase): - @expand(product(_h_spaces + _p_RG_spaces + _p_spaces + @expand(product([(ift.RGSpace(10, harmonic=True), 4, 0), + (ift.RGSpace((24, 31), distances=(0.4, 2.34), + harmonic=True), 3, 0), + (ift.LMSpace(4), 10, 0)], + [np.float64, np.complex128])) + def testSlopeOperator(self, args, dtype): + tmp = ift.ExpTransform(ift.PowerSpace(args[0]), args[1], args[2]) + tgt = tmp.domain[0] + sig = np.array([0.3, 0.13]) + dom = ift.UnstructuredDomain(2) + op = ift.SlopeOperator(dom, tgt, sig) + ift.extra.consistency_check(op, dtype, dtype) + + @expand(product(_h_spaces + _p_spaces + _pow_spaces, + _h_spaces + _p_spaces + _pow_spaces, + [np.float64, np.complex128])) + def testSelectionOperator(self, sp1, sp2, dtype): + mdom = ift.MultiDomain.make({'a':sp1, 'b':sp2}) + op = ift.SelectionOperator(mdom, 'a') + ift.extra.consistency_check(op, dtype, dtype) + + @expand(product(_h_spaces + _p_spaces + _pow_spaces, + [np.float64, np.complex128])) + def testSandwichOperator(self, sp, dtype): + bun = ift.DiagonalOperator(ift.Field.from_random("normal", sp, + dtype=dtype)) + cheese = ift.DiagonalOperator(ift.Field.from_random("normal", sp, + dtype=dtype)) + op = ift.SandwichOperator.make(bun, cheese) + ift.extra.consistency_check(op, dtype, dtype) + + @expand(product(_h_spaces + _p_spaces + _pow_spaces, + [np.float64, np.complex128])) + def testOperatorAdaptor(self, sp, dtype): + op = ift.DiagonalOperator(ift.Field.from_random("normal", sp, + dtype=dtype)) + ift.extra.consistency_check(op.adjoint, dtype, dtype) + ift.extra.consistency_check(op.inverse, dtype, dtype) + ift.extra.consistency_check(op.inverse.adjoint, dtype, dtype) + ift.extra.consistency_check(op.adjoint.inverse, dtype, dtype) + + @expand(product(_h_spaces + _p_spaces + _pow_spaces, + _h_spaces + _p_spaces + _pow_spaces, + [np.float64, np.complex128])) + def testNullOperator(self, sp1, sp2, dtype): + op = ift.NullOperator(sp1, sp2) + ift.extra.consistency_check(op, dtype, dtype) + mdom1 = ift.MultiDomain.make({'a':sp1}) + mdom2 = ift.MultiDomain.make({'b':sp2}) + op = ift.NullOperator(mdom1, mdom2) + ift.extra.consistency_check(op, dtype, dtype) + op = ift.NullOperator(sp1, mdom2) + ift.extra.consistency_check(op, dtype, dtype) + op = ift.NullOperator(mdom1, sp2) + ift.extra.consistency_check(op, dtype, dtype) + + @expand(product(_h_spaces + _p_spaces + + _pow_spaces, + [np.float64, np.complex128])) + def testMultiAdaptor(self, sp, dtype): + mdom = ift.MultiDomain.make({'a':sp}) + op = ift.MultiAdaptor(mdom) + ift.extra.consistency_check(op, dtype, dtype) + + @expand(product(_p_RG_spaces, + [np.float64, np.complex128])) + def testHarmonicSmoothingOperator(self, sp, dtype): + op = ift.HarmonicSmoothingOperator(sp, 0.1) + ift.extra.consistency_check(op, dtype, dtype) + + @expand(product(_h_spaces + _p_spaces + _pow_spaces, [np.float64, np.complex128])) def testDOFDistributor(self, sp, dtype): -- GitLab