diff --git a/nifty4/__init__.py b/nifty4/__init__.py
index 65f44357a125b2db8b04c8e1d03de66368ae7fa6..55f9a9f9b3306d16cfdd72a66e083254ea1033c1 100644
--- a/nifty4/__init__.py
+++ b/nifty4/__init__.py
@@ -53,6 +53,7 @@ from .minimization.line_energy import LineEnergy
from .sugar import *
from .plotting.plot import plot
from . import library
+from . import extra
__all__ = ["Domain", "UnstructuredDomain", "StructuredDomain", "RGSpace", "LMSpace",
"HPSpace", "GLSpace", "DOFSpace", "PowerSpace", "DomainTuple",
diff --git a/nifty4/extra/__init__.py b/nifty4/extra/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0214b80c3664e07d16d584e029a5c8c3ad7ae1c
--- /dev/null
+++ b/nifty4/extra/__init__.py
@@ -0,0 +1 @@
+from .operator_tests import consistency_check
diff --git a/nifty4/extra/operator_tests.py b/nifty4/extra/operator_tests.py
new file mode 100644
index 0000000000000000000000000000000000000000..48ee48d6ce97b58bb96b15d31616e6df2b550947
--- /dev/null
+++ b/nifty4/extra/operator_tests.py
@@ -0,0 +1,65 @@
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+#
+# Copyright(C) 2013-2018 Max-Planck-Society
+#
+# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
+# and financially supported by the Studienstiftung des deutschen Volkes.
+
+import numpy as np
+from ..field import Field
+from .. import dobj
+
+__all__ = ["consistency_check"]
+
+
+def adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol):
+ needed_cap = op.TIMES | op.ADJOINT_TIMES
+ if (op.capability & needed_cap) != needed_cap:
+ return
+ f1 = Field.from_random("normal", op.domain, dtype=domain_dtype)
+ f2 = Field.from_random("normal", op.target, dtype=target_dtype)
+ res1 = f1.vdot(op.adjoint_times(f2))
+ res2 = op.times(f1).vdot(f2)
+ np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)
+
+
+def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
+ needed_cap = op.TIMES | op.INVERSE_TIMES
+ if (op.capability & needed_cap) != needed_cap:
+ return
+ foo = Field.from_random("normal", op.target, dtype=target_dtype)
+ res = op(op.inverse_times(foo))
+ np.testing.assert_allclose(dobj.to_global_data(res.val),
+ dobj.to_global_data(foo.val),
+ atol=atol, rtol=rtol)
+
+ foo = Field.from_random("normal", op.domain, dtype=domain_dtype)
+ res = op.inverse_times(op(foo))
+ np.testing.assert_allclose(dobj.to_global_data(res.val),
+ dobj.to_global_data(foo.val),
+ atol=atol, rtol=rtol)
+
+
+def full_implementation(op, domain_dtype, target_dtype, atol, rtol):
+ adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol)
+ inverse_implementation(op, domain_dtype, target_dtype, atol, rtol)
+
+
+def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
+ atol=0, rtol=1e-7):
+ full_implementation(op, domain_dtype, target_dtype, atol, rtol)
+ full_implementation(op.adjoint, target_dtype, domain_dtype, atol, rtol)
+ full_implementation(op.inverse, target_dtype, domain_dtype, atol, rtol)
+ full_implementation(op.adjoint.inverse, domain_dtype, target_dtype, atol,
+ rtol)
diff --git a/nifty4/operators/linear_operator.py b/nifty4/operators/linear_operator.py
index e923aa548fee6036bf38e75c89bab5aaf3565235..ec692d245c98aa3b59c1146760796ed85a55c99e 100644
--- a/nifty4/operators/linear_operator.py
+++ b/nifty4/operators/linear_operator.py
@@ -161,5 +161,4 @@ class LinearOperator(with_metaclass(
self._check_mode(mode)
if x.domain != self._dom(mode):
- raise ValueError("The operator's and and field's domains "
- "don't match.")
+ raise ValueError("The operator's and field's domains don't match.")
diff --git a/test/test_energies/test_map.py b/test/test_energies/test_map.py
index bac03da225f7387ca3dcda0d315add77f2f92a75..bbce2de88713c0b58b70500304b5b31708026aa0 100644
--- a/test/test_energies/test_map.py
+++ b/test/test_energies/test_map.py
@@ -259,8 +259,6 @@ class Curvature_Tests(unittest.TestCase):
a = (gradient1 - gradient0) / eps
b = energy0.curvature(direction)
- print(a.vdot(a))
- print(b.vdot(b))
tol = 1e-7
assert_allclose(ift.dobj.to_global_data(a.val),
ift.dobj.to_global_data(b.val), rtol=tol, atol=tol)
diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py
index 9472dbce3f07c965608fbe37faf0ed38d0c93a7e..0579f857d261c3f7bf4cf63cf318137f92dec043 100644
--- a/test/test_operators/test_adjoint.py
+++ b/test/test_operators/test_adjoint.py
@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
#
-# Copyright(C) 2013-2017 Max-Planck-Society
+# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
@@ -21,23 +21,6 @@ import nifty4 as ift
import numpy as np
from itertools import product
from test.common import expand
-from numpy.testing import assert_allclose
-
-
-def _check_adjointness(op, dtype=np.float64):
- f1 = ift.Field.from_random("normal", domain=op.domain, dtype=dtype)
- f2 = ift.Field.from_random("normal", domain=op.target, dtype=dtype)
- cap = op.capability
- if ((cap & ift.LinearOperator.TIMES) and
- (cap & ift.LinearOperator.ADJOINT_TIMES)):
- assert_allclose(f1.vdot(op.adjoint_times(f2)),
- op.times(f1).vdot(f2),
- rtol=1e-8)
- if ((cap & ift.LinearOperator.INVERSE_TIMES) and
- (cap & ift.LinearOperator.INVERSE_ADJOINT_TIMES)):
- assert_allclose(f1.vdot(op.inverse_times(f2)),
- op.inverse_adjoint_times(f1).vdot(f2),
- rtol=1e-8)
_h_RG_spaces = [ift.RGSpace(7, distances=0.2, harmonic=True),
@@ -49,29 +32,35 @@ _p_RG_spaces = [ift.RGSpace(19, distances=0.7),
_p_spaces = _p_RG_spaces + [ift.HPSpace(17), ift.GLSpace(8, 13)]
-class Adjointness_Tests(unittest.TestCase):
+class Consistency_Tests(unittest.TestCase):
@expand(product(_h_spaces, [np.float64, np.complex128]))
def testPPO(self, sp, dtype):
op = ift.PowerProjectionOperator(sp)
- _check_adjointness(op, dtype)
+ ift.extra.consistency_check(op, dtype, dtype)
ps = ift.PowerSpace(
sp, ift.PowerSpace.useful_binbounds(sp, logarithmic=False, nbin=3))
op = ift.PowerProjectionOperator(sp, ps)
- _check_adjointness(op, dtype)
+ ift.extra.consistency_check(op, dtype, dtype)
ps = ift.PowerSpace(
sp, ift.PowerSpace.useful_binbounds(sp, logarithmic=True, nbin=3))
op = ift.PowerProjectionOperator(sp, ps)
- _check_adjointness(op, dtype)
+ ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_RG_spaces+_p_RG_spaces,
[np.float64, np.complex128]))
def testFFT(self, sp, dtype):
op = ift.FFTOperator(sp)
- _check_adjointness(op, dtype)
+ ift.extra.consistency_check(op, dtype, dtype)
op = ift.FFTOperator(sp.get_default_codomain())
- _check_adjointness(op, dtype)
+ ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces, [np.float64, np.complex128]))
def testHarmonic(self, sp, dtype):
op = ift.HarmonicTransformOperator(sp)
- _check_adjointness(op, dtype)
+ ift.extra.consistency_check(op, dtype, dtype)
+
+ @expand(product(_h_spaces+_p_spaces, [np.float64, np.complex128]))
+ def testDiagonal(self, sp, dtype):
+ op = ift.DiagonalOperator(ift.Field.from_random("normal", sp,
+ dtype=dtype))
+ ift.extra.consistency_check(op, dtype, dtype)