diff --git a/nifty4/__init__.py b/nifty4/__init__.py index 65f44357a125b2db8b04c8e1d03de66368ae7fa6..55f9a9f9b3306d16cfdd72a66e083254ea1033c1 100644 --- a/nifty4/__init__.py +++ b/nifty4/__init__.py @@ -53,6 +53,7 @@ from .minimization.line_energy import LineEnergy from .sugar import * from .plotting.plot import plot from . import library +from . import extra __all__ = ["Domain", "UnstructuredDomain", "StructuredDomain", "RGSpace", "LMSpace", "HPSpace", "GLSpace", "DOFSpace", "PowerSpace", "DomainTuple", diff --git a/nifty4/extra/__init__.py b/nifty4/extra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0214b80c3664e07d16d584e029a5c8c3ad7ae1c --- /dev/null +++ b/nifty4/extra/__init__.py @@ -0,0 +1 @@ +from .operator_tests import consistency_check diff --git a/nifty4/extra/operator_tests.py b/nifty4/extra/operator_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..48ee48d6ce97b58bb96b15d31616e6df2b550947 --- /dev/null +++ b/nifty4/extra/operator_tests.py @@ -0,0 +1,65 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# +# Copyright(C) 2013-2018 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik +# and financially supported by the Studienstiftung des deutschen Volkes. + +import numpy as np +from ..field import Field +from .. import dobj + +__all__ = ["consistency_check"] + + +def adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol): + needed_cap = op.TIMES | op.ADJOINT_TIMES + if (op.capability & needed_cap) != needed_cap: + return + f1 = Field.from_random("normal", op.domain, dtype=domain_dtype) + f2 = Field.from_random("normal", op.target, dtype=target_dtype) + res1 = f1.vdot(op.adjoint_times(f2)) + res2 = op.times(f1).vdot(f2) + np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol) + + +def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol): + needed_cap = op.TIMES | op.INVERSE_TIMES + if (op.capability & needed_cap) != needed_cap: + return + foo = Field.from_random("normal", op.target, dtype=target_dtype) + res = op(op.inverse_times(foo)) + np.testing.assert_allclose(dobj.to_global_data(res.val), + dobj.to_global_data(foo.val), + atol=atol, rtol=rtol) + + foo = Field.from_random("normal", op.domain, dtype=domain_dtype) + res = op.inverse_times(op(foo)) + np.testing.assert_allclose(dobj.to_global_data(res.val), + dobj.to_global_data(foo.val), + atol=atol, rtol=rtol) + + +def full_implementation(op, domain_dtype, target_dtype, atol, rtol): + adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol) + inverse_implementation(op, domain_dtype, target_dtype, atol, rtol) + + +def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64, + atol=0, rtol=1e-7): + full_implementation(op, domain_dtype, target_dtype, atol, rtol) + full_implementation(op.adjoint, target_dtype, domain_dtype, atol, rtol) + full_implementation(op.inverse, target_dtype, domain_dtype, atol, rtol) + full_implementation(op.adjoint.inverse, domain_dtype, target_dtype, atol, + rtol) diff --git a/nifty4/operators/linear_operator.py b/nifty4/operators/linear_operator.py index e923aa548fee6036bf38e75c89bab5aaf3565235..ec692d245c98aa3b59c1146760796ed85a55c99e 100644 --- a/nifty4/operators/linear_operator.py +++ b/nifty4/operators/linear_operator.py @@ -161,5 +161,4 @@ class LinearOperator(with_metaclass( self._check_mode(mode) if x.domain != self._dom(mode): - raise ValueError("The operator's and and field's domains " - "don't match.") + raise ValueError("The operator's and field's domains don't match.") diff --git a/test/test_energies/test_map.py b/test/test_energies/test_map.py index bac03da225f7387ca3dcda0d315add77f2f92a75..bbce2de88713c0b58b70500304b5b31708026aa0 100644 --- a/test/test_energies/test_map.py +++ b/test/test_energies/test_map.py @@ -259,8 +259,6 @@ class Curvature_Tests(unittest.TestCase): a = (gradient1 - gradient0) / eps b = energy0.curvature(direction) - print(a.vdot(a)) - print(b.vdot(b)) tol = 1e-7 assert_allclose(ift.dobj.to_global_data(a.val), ift.dobj.to_global_data(b.val), rtol=tol, atol=tol) diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py index 9472dbce3f07c965608fbe37faf0ed38d0c93a7e..0579f857d261c3f7bf4cf63cf318137f92dec043 100644 --- a/test/test_operators/test_adjoint.py +++ b/test/test_operators/test_adjoint.py @@ -11,7 +11,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . # -# Copyright(C) 2013-2017 Max-Planck-Society +# Copyright(C) 2013-2018 Max-Planck-Society # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. @@ -21,23 +21,6 @@ import nifty4 as ift import numpy as np from itertools import product from test.common import expand -from numpy.testing import assert_allclose - - -def _check_adjointness(op, dtype=np.float64): - f1 = ift.Field.from_random("normal", domain=op.domain, dtype=dtype) - f2 = ift.Field.from_random("normal", domain=op.target, dtype=dtype) - cap = op.capability - if ((cap & ift.LinearOperator.TIMES) and - (cap & ift.LinearOperator.ADJOINT_TIMES)): - assert_allclose(f1.vdot(op.adjoint_times(f2)), - op.times(f1).vdot(f2), - rtol=1e-8) - if ((cap & ift.LinearOperator.INVERSE_TIMES) and - (cap & ift.LinearOperator.INVERSE_ADJOINT_TIMES)): - assert_allclose(f1.vdot(op.inverse_times(f2)), - op.inverse_adjoint_times(f1).vdot(f2), - rtol=1e-8) _h_RG_spaces = [ift.RGSpace(7, distances=0.2, harmonic=True), @@ -49,29 +32,35 @@ _p_RG_spaces = [ift.RGSpace(19, distances=0.7), _p_spaces = _p_RG_spaces + [ift.HPSpace(17), ift.GLSpace(8, 13)] -class Adjointness_Tests(unittest.TestCase): +class Consistency_Tests(unittest.TestCase): @expand(product(_h_spaces, [np.float64, np.complex128])) def testPPO(self, sp, dtype): op = ift.PowerProjectionOperator(sp) - _check_adjointness(op, dtype) + ift.extra.consistency_check(op, dtype, dtype) ps = ift.PowerSpace( sp, ift.PowerSpace.useful_binbounds(sp, logarithmic=False, nbin=3)) op = ift.PowerProjectionOperator(sp, ps) - _check_adjointness(op, dtype) + ift.extra.consistency_check(op, dtype, dtype) ps = ift.PowerSpace( sp, ift.PowerSpace.useful_binbounds(sp, logarithmic=True, nbin=3)) op = ift.PowerProjectionOperator(sp, ps) - _check_adjointness(op, dtype) + ift.extra.consistency_check(op, dtype, dtype) @expand(product(_h_RG_spaces+_p_RG_spaces, [np.float64, np.complex128])) def testFFT(self, sp, dtype): op = ift.FFTOperator(sp) - _check_adjointness(op, dtype) + ift.extra.consistency_check(op, dtype, dtype) op = ift.FFTOperator(sp.get_default_codomain()) - _check_adjointness(op, dtype) + ift.extra.consistency_check(op, dtype, dtype) @expand(product(_h_spaces, [np.float64, np.complex128])) def testHarmonic(self, sp, dtype): op = ift.HarmonicTransformOperator(sp) - _check_adjointness(op, dtype) + ift.extra.consistency_check(op, dtype, dtype) + + @expand(product(_h_spaces+_p_spaces, [np.float64, np.complex128])) + def testDiagonal(self, sp, dtype): + op = ift.DiagonalOperator(ift.Field.from_random("normal", sp, + dtype=dtype)) + ift.extra.consistency_check(op, dtype, dtype)