Commit b7a10368 authored by Lukas Platz's avatar Lukas Platz
Browse files

extra.check_linear_operator: add option to test on positive input values only

parent b5e93218
Pipeline #97526 passed with stages
in 11 minutes and 52 seconds
......@@ -38,7 +38,8 @@ __all__ = ["check_linear_operator", "check_operator", "assert_allclose", "minisa
def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64,
atol=1e-12, rtol=1e-12, only_r_linear=False):
atol=1e-12, rtol=1e-12, only_r_linear=False,
pos_dom=False, pos_tgt=False):
"""Checks an operator for algebraic consistency of its capabilities.
Checks whether times(), adjoint_times(), inverse_times() and
......@@ -73,18 +74,18 @@ def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64,
_domain_check_linear(op.adjoint, target_dtype)
_domain_check_linear(op.inverse, target_dtype)
_domain_check_linear(op.adjoint.inverse, domain_dtype)
_check_linearity(op, domain_dtype, atol, rtol)
_check_linearity(op.adjoint, target_dtype, atol, rtol)
_check_linearity(op.inverse, target_dtype, atol, rtol)
_check_linearity(op.adjoint.inverse, domain_dtype, atol, rtol)
_check_linearity(op, domain_dtype, atol, rtol, pos_dom)
_check_linearity(op.adjoint, target_dtype, atol, rtol, pos_dom)
_check_linearity(op.inverse, target_dtype, atol, rtol, pos_dom)
_check_linearity(op.adjoint.inverse, domain_dtype, atol, rtol, pos_dom)
_full_implementation(op, domain_dtype, target_dtype, atol, rtol,
only_r_linear)
only_r_linear, pos_dom, pos_tgt)
_full_implementation(op.adjoint, target_dtype, domain_dtype, atol, rtol,
only_r_linear)
only_r_linear, pos_dom, pos_tgt)
_full_implementation(op.inverse, target_dtype, domain_dtype, atol, rtol,
only_r_linear)
only_r_linear, pos_dom, pos_tgt)
_full_implementation(op.adjoint.inverse, domain_dtype, target_dtype, atol,
rtol, only_r_linear)
rtol, only_r_linear, pos_dom, pos_tgt)
_check_sqrt(op, domain_dtype)
_check_sqrt(op.adjoint, target_dtype)
_check_sqrt(op.inverse, target_dtype)
......@@ -143,12 +144,16 @@ def assert_equal(f1, f2):
def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol,
only_r_linear):
only_r_linear, pos_dom=False, pos_tgt=False):
needed_cap = op.TIMES | op.ADJOINT_TIMES
if (op.capability & needed_cap) != needed_cap:
return
f1 = from_random(op.domain, "normal", dtype=domain_dtype)
f2 = from_random(op.target, "normal", dtype=target_dtype)
if pos_dom:
f1 = f1.exp()
if pos_tgt:
f2 = f2.exp()
res1 = f1.s_vdot(op.adjoint_times(f2))
res2 = op.times(f1).s_vdot(f2)
if only_r_linear:
......@@ -156,32 +161,41 @@ def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol,
np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)
def _inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
def _inverse_implementation(op, domain_dtype, target_dtype, atol, rtol,
pos_dom=False, pos_tgt=False):
needed_cap = op.TIMES | op.INVERSE_TIMES
if (op.capability & needed_cap) != needed_cap:
return
foo = from_random(op.target, "normal", dtype=target_dtype)
if pos_tgt:
foo = foo.exp()
res = op(op.inverse_times(foo))
assert_allclose(res, foo, atol=atol, rtol=rtol)
foo = from_random(op.domain, "normal", dtype=domain_dtype)
if pos_dom:
foo = foo.exp()
res = op.inverse_times(op(foo))
assert_allclose(res, foo, atol=atol, rtol=rtol)
def _full_implementation(op, domain_dtype, target_dtype, atol, rtol,
only_r_linear):
only_r_linear, pos_dom=False, pos_tgt=False):
_adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol,
only_r_linear)
_inverse_implementation(op, domain_dtype, target_dtype, atol, rtol)
only_r_linear, pos_dom, pos_tgt)
_inverse_implementation(op, domain_dtype, target_dtype, atol, rtol,
pos_dom, pos_tgt)
def _check_linearity(op, domain_dtype, atol, rtol):
def _check_linearity(op, domain_dtype, atol, rtol, pos_dom=False):
needed_cap = op.TIMES
if (op.capability & needed_cap) != needed_cap:
return
fld1 = from_random(op.domain, "normal", dtype=domain_dtype)
fld2 = from_random(op.domain, "normal", dtype=domain_dtype)
if pos_dom:
fld1 = fld1.exp()
fld2 = fld2.exp()
alpha = 0.42
val1 = op(alpha*fld1+fld2)
val2 = alpha*op(fld1)+op(fld2)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment