Commit 890a6b13 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'r_linear' into 'NIFTy_5'

introduce relaxed adjointness tests for R-linear operators

See merge request !312
parents ec9e1f7f 958100c5
Pipeline #46730 passed with stages
in 8 minutes and 33 seconds
...@@ -33,7 +33,8 @@ def _assert_allclose(f1, f2, atol, rtol): ...@@ -33,7 +33,8 @@ def _assert_allclose(f1, f2, atol, rtol):
_assert_allclose(val, f2[key], atol=atol, rtol=rtol) _assert_allclose(val, f2[key], atol=atol, rtol=rtol)
def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol): def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol,
only_r_linear):
needed_cap = op.TIMES | op.ADJOINT_TIMES needed_cap = op.TIMES | op.ADJOINT_TIMES
if (op.capability & needed_cap) != needed_cap: if (op.capability & needed_cap) != needed_cap:
return return
...@@ -41,6 +42,8 @@ def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol): ...@@ -41,6 +42,8 @@ def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol):
f2 = from_random("normal", op.target, dtype=target_dtype) f2 = from_random("normal", op.target, dtype=target_dtype)
res1 = f1.vdot(op.adjoint_times(f2)) res1 = f1.vdot(op.adjoint_times(f2))
res2 = op.times(f1).vdot(f2) res2 = op.times(f1).vdot(f2)
if only_r_linear:
res1, res2 = res1.real, res2.real
np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol) np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)
...@@ -57,8 +60,10 @@ def _inverse_implementation(op, domain_dtype, target_dtype, atol, rtol): ...@@ -57,8 +60,10 @@ def _inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
_assert_allclose(res, foo, atol=atol, rtol=rtol) _assert_allclose(res, foo, atol=atol, rtol=rtol)
def _full_implementation(op, domain_dtype, target_dtype, atol, rtol): def _full_implementation(op, domain_dtype, target_dtype, atol, rtol,
_adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol) only_r_linear):
_adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol,
only_r_linear)
_inverse_implementation(op, domain_dtype, target_dtype, atol, rtol) _inverse_implementation(op, domain_dtype, target_dtype, atol, rtol)
...@@ -72,7 +77,7 @@ def _check_linearity(op, domain_dtype, atol, rtol): ...@@ -72,7 +77,7 @@ def _check_linearity(op, domain_dtype, atol, rtol):
def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64, def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
atol=0, rtol=1e-7): atol=0, rtol=1e-7, only_r_linear=False):
""" """
Checks an operator for algebraic consistency of its capabilities. Checks an operator for algebraic consistency of its capabilities.
...@@ -98,15 +103,21 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64, ...@@ -98,15 +103,21 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
Relative tolerance for the check. If atol is specified, Relative tolerance for the check. If atol is specified,
then satisfying any tolerance will let the check pass. then satisfying any tolerance will let the check pass.
Default: 0. Default: 0.
only_r_linear: bool
set to True if the operator is only R-linear, not C-linear.
This will relax the adjointness test accordingly.
""" """
if not isinstance(op, LinearOperator): if not isinstance(op, LinearOperator):
raise TypeError('This test tests only linear operators.') raise TypeError('This test tests only linear operators.')
_check_linearity(op, domain_dtype, atol, rtol) _check_linearity(op, domain_dtype, atol, rtol)
_full_implementation(op, domain_dtype, target_dtype, atol, rtol) _full_implementation(op, domain_dtype, target_dtype, atol, rtol,
_full_implementation(op.adjoint, target_dtype, domain_dtype, atol, rtol) only_r_linear)
_full_implementation(op.inverse, target_dtype, domain_dtype, atol, rtol) _full_implementation(op.adjoint, target_dtype, domain_dtype, atol, rtol,
only_r_linear)
_full_implementation(op.inverse, target_dtype, domain_dtype, atol, rtol,
only_r_linear)
_full_implementation(op.adjoint.inverse, domain_dtype, target_dtype, atol, _full_implementation(op.adjoint.inverse, domain_dtype, target_dtype, atol,
rtol) rtol, only_r_linear)
def _get_acceptable_location(op, loc, lin): def _get_acceptable_location(op, loc, lin):
......
...@@ -71,6 +71,20 @@ def testLinearInterpolator(): ...@@ -71,6 +71,20 @@ def testLinearInterpolator():
ift.extra.consistency_check(op) ift.extra.consistency_check(op)
@pmp('sp', _h_spaces + _p_spaces + _pow_spaces)
def testRealizer(sp):
op = ift.Realizer(sp)
ift.extra.consistency_check(op, np.complex128, np.float64,
only_r_linear=True)
@pmp('sp', _h_spaces + _p_spaces + _pow_spaces)
def testConjugationOperator(sp):
op = ift.ConjugationOperator(sp)
ift.extra.consistency_check(op, np.complex128, np.complex128,
only_r_linear=True)
@pmp('args', [(ift.RGSpace(10, harmonic=True), 4, 0), (ift.RGSpace( @pmp('args', [(ift.RGSpace(10, harmonic=True), 4, 0), (ift.RGSpace(
(24, 31), distances=(0.4, 2.34), harmonic=True), 3, 0), (24, 31), distances=(0.4, 2.34), harmonic=True), 3, 0),
(ift.LMSpace(4), 10, 0)]) (ift.LMSpace(4), 10, 0)])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment