Commit 441e854e authored by Philipp Arras's avatar Philipp Arras
Browse files

Add more automatic checks for operators

parent 046d074c
Pipeline #63045 failed with stages
in 5 minutes and 53 seconds
......@@ -16,11 +16,13 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from numpy.testing import assert_
from .domain_tuple import DomainTuple
from .field import Field
from .linearization import Linearization
from .multi_domain import MultiDomain
from .multi_field import MultiField
from .operators.linear_operator import LinearOperator
from .sugar import from_random
......@@ -81,6 +83,38 @@ def _check_linearity(op, domain_dtype, atol, rtol):
_assert_allclose(val1, val2, atol=atol, rtol=rtol)
def _actual_domain_check(op, domain_dtype=None, inp=None):
needed_cap = op.TIMES
if (op.capability & needed_cap) != needed_cap:
return
if domain_dtype is not None:
inp = from_random("normal", op.domain, dtype=domain_dtype)
elif inp is None:
raise ValueError('Need to specify either dtype or inp')
assert_(inp.domain is op.domain)
assert_(op(inp).domain is op.target)
def _actual_domain_check_nonlinear(op, loc):
assert isinstance(loc, (Field, MultiField))
assert_(loc.domain is op.domain)
lin = Linearization.make_var(loc, False)
reslin = op(lin)
assert_(lin.domain is op.domain)
assert_(lin.target is op.domain)
assert_(lin.val.domain is lin.domain)
assert_(reslin.domain is op.domain)
assert_(reslin.target is op.target)
assert_(reslin.val.domain is reslin.target)
assert_(reslin.target is op.target)
assert_(reslin.jac.domain is reslin.domain)
assert_(reslin.jac.target is reslin.target)
_actual_domain_check(reslin.jac, inp=loc)
_actual_domain_check(reslin.jac.adjoint, domain_dtype=np.float64)
def _domain_check(op):
for dd in [op.domain, op.target]:
if not isinstance(dd, (DomainTuple, MultiDomain)):
......@@ -123,6 +157,10 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
if not isinstance(op, LinearOperator):
raise TypeError('This test tests only linear operators.')
_domain_check(op)
_actual_domain_check(op, domain_dtype)
_actual_domain_check(op.adjoint, domain_dtype)
_actual_domain_check(op.inverse, domain_dtype)
_actual_domain_check(op.adjoint.inverse, domain_dtype)
_check_linearity(op, domain_dtype, atol, rtol)
_check_linearity(op.adjoint, target_dtype, atol, rtol)
_check_linearity(op.inverse, target_dtype, atol, rtol)
......@@ -180,6 +218,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100):
Tolerance for the check.
"""
_domain_check(op)
_actual_domain_check_nonlinear(op, loc)
for _ in range(ntries):
lin = op(Linearization.make_var(loc))
loc2, lin2 = _get_acceptable_location(op, loc, lin)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment