diff --git a/src/extra.py b/src/extra.py index 24966a5e89c9f833268279d787f8d426f092aa5a..699a1e5145efa036aa43ea48bd32602bcb3b186b 100644 --- a/src/extra.py +++ b/src/extra.py @@ -73,6 +73,10 @@ def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64, _domain_check_linear(op.adjoint, target_dtype) _domain_check_linear(op.inverse, target_dtype) _domain_check_linear(op.adjoint.inverse, domain_dtype) + _purity_check(op, from_random(op.domain, dtype=domain_dtype)) + _purity_check(op.adjoint.inverse, from_random(op.domain, dtype=domain_dtype)) + _purity_check(op.adjoint, from_random(op.target, dtype=target_dtype)) + _purity_check(op.inverse, from_random(op.target, dtype=target_dtype)) _check_linearity(op, domain_dtype, atol, rtol) _check_linearity(op.adjoint, target_dtype, atol, rtol) _check_linearity(op.inverse, target_dtype, atol, rtol) @@ -120,6 +124,7 @@ def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True, if not isinstance(op, Operator): raise TypeError('This test tests only (nonlinear) operators.') _domain_check_nonlinear(op, loc) + _purity_check(op, loc) _performance_check(op, loc, bool(perf_check)) _linearization_value_consistency(op, loc) _jac_vs_finite_differences(op, loc, np.sqrt(tol), ntries, @@ -288,6 +293,14 @@ def _performance_check(op, pos, raise_on_fail): raise RuntimeError(s) +def _purity_check(op, pos): + if isinstance(op, LinearOperator) and (op.capability & op.TIMES) != op.TIMES: + return + res0 = op(pos) + res1 = op(pos) + assert_equal(res0, res1) + + def _get_acceptable_location(op, loc, lin): if not np.isfinite(lin.val.s_sum()): raise ValueError('Initial value must be finite') diff --git a/test/test_extra.py b/test/test_extra.py new file mode 100644 index 0000000000000000000000000000000000000000..51f68c717b374a9b8b6b941d4f9863ec1a41e2b2 --- /dev/null +++ b/test/test_extra.py @@ -0,0 +1,64 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2021 Max-Planck-Society +# Author: Philipp Arras + +import numpy as np +import pytest + +import nifty7 as ift +from time import time + +from .common import list2fixture, setup_function, teardown_function + + +pmp = pytest.mark.parametrize + + +class NonPureOperator(ift.Operator): + def __init__(self, domain): + self._domain = self._target = ift.makeDomain(domain) + + def apply(self, x): + self._check_input(x) + return x*time() + + +class NonPureLinearOperator(ift.LinearOperator): + def __init__(self, domain, cap): + self._domain = self._target = ift.makeDomain(domain) + self._capability = cap + + def apply(self, x, mode): + self._check_input(x, mode) + return x*time() + + +@pmp("cap", [ift.LinearOperator.ADJOINT_TIMES, + ift.LinearOperator.INVERSE_TIMES | ift.LinearOperator.TIMES]) +@pmp("ddtype", [np.float64, np.complex128]) +@pmp("tdtype", [np.float64, np.complex128]) +def test_purity_check_linear(cap, ddtype, tdtype): + dom = ift.RGSpace(2) + op = NonPureLinearOperator(dom, cap) + with pytest.raises(AssertionError): + ift.extra.check_linear_operator(op, ddtype, tdtype) + + +@pmp("dtype", [np.float64, np.complex128]) +def test_purity_check(dtype): + dom = ift.RGSpace(2) + op = NonPureOperator(dom) + with pytest.raises(AssertionError): + ift.extra.check_operator(op, dtype)