Commit a778ee76 authored by Philipp Arras's avatar Philipp Arras

Better tests

parent a50ec021
......@@ -83,7 +83,7 @@ def _check_linearity(op, domain_dtype, atol, rtol):
assert_allclose(val1, val2, atol=atol, rtol=rtol)
def _actual_domain_check(op, domain_dtype=None, inp=None):
def _actual_domain_check_linear(op, domain_dtype=None, inp=None):
needed_cap = op.TIMES
if (op.capability & needed_cap) != needed_cap:
return
......@@ -98,21 +98,25 @@ def _actual_domain_check(op, domain_dtype=None, inp=None):
def _actual_domain_check_nonlinear(op, loc):
assert isinstance(loc, (Field, MultiField))
assert_(loc.domain is op.domain)
lin = Linearization.make_var(loc, False)
reslin = op(lin)
assert_(lin.domain is op.domain)
assert_(lin.target is op.domain)
assert_(lin.val.domain is lin.domain)
for wm in [False, True]:
lin = Linearization.make_var(loc, wm)
reslin = op(lin)
assert_(lin.domain is op.domain)
assert_(lin.target is op.domain)
assert_(lin.val.domain is lin.domain)
assert_(reslin.domain is op.domain)
assert_(reslin.target is op.target)
assert_(reslin.val.domain is reslin.target)
assert_(reslin.domain is op.domain)
assert_(reslin.target is op.target)
assert_(reslin.val.domain is reslin.target)
assert_(reslin.target is op.target)
assert_(reslin.jac.domain is reslin.domain)
assert_(reslin.jac.target is reslin.target)
_actual_domain_check(reslin.jac, inp=loc)
_actual_domain_check(reslin.jac.adjoint, inp=reslin.jac(loc))
assert_(reslin.target is op.target)
assert_(reslin.jac.domain is reslin.domain)
assert_(reslin.jac.target is reslin.target)
_actual_domain_check_linear(reslin.jac, inp=loc)
_actual_domain_check_linear(reslin.jac.adjoint, inp=reslin.jac(loc))
if wm:
assert_(reslin.metric.domain is reslin.metric.target)
assert_(reslin.metric.domain is op.domain)
def _domain_check(op):
......@@ -140,16 +144,16 @@ def _performance_check(op, pos, raise_on_fail):
return self._count
for wm in [False, True]:
cop = CountingOp(op.domain)
op = op @ cop
op(pos)
myop = op @ cop
myop(pos)
cond = [cop.count != 1]
lin = op(2*Linearization.make_var(pos, wm))
lin = myop(2*Linearization.make_var(pos, wm))
cond.append(cop.count != 2)
lin.jac(pos)
cond.append(cop.count != 3)
lin.jac.adjoint(lin.val)
cond.append(cop.count != 4)
if wm and op.target is DomainTuple.scalar_domain():
if wm and myop.target is DomainTuple.scalar_domain():
lin.metric(pos)
cond.append(cop.count != 6)
if any(cond):
......@@ -195,10 +199,10 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
if not isinstance(op, LinearOperator):
raise TypeError('This test tests only linear operators.')
_domain_check(op)
_actual_domain_check(op, domain_dtype)
_actual_domain_check(op.adjoint, target_dtype)
_actual_domain_check(op.inverse, target_dtype)
_actual_domain_check(op.adjoint.inverse, domain_dtype)
_actual_domain_check_linear(op, domain_dtype)
_actual_domain_check_linear(op.adjoint, target_dtype)
_actual_domain_check_linear(op.inverse, target_dtype)
_actual_domain_check_linear(op.adjoint.inverse, domain_dtype)
_check_linearity(op, domain_dtype, atol, rtol)
_check_linearity(op.adjoint, target_dtype, atol, rtol)
_check_linearity(op.inverse, target_dtype, atol, rtol)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment