Commit b1896312 authored by Philipp Arras's avatar Philipp Arras

Add linearity check to consistency_check

parent 1f1f1e1e
......@@ -19,6 +19,7 @@ import numpy as np
from .field import Field
from .linearization import Linearization
from .operators.linear_operator import LinearOperator
from .sugar import from_random
__all__ = ["consistency_check", "check_value_gradient_consistency",
......@@ -62,8 +63,20 @@ def _full_implementation(op, domain_dtype, target_dtype, atol, rtol):
_inverse_implementation(op, domain_dtype, target_dtype, atol, rtol)
def _check_linearity(op, domain_dtype, atol, rtol):
fld1 = from_random("normal", op.domain, dtype=domain_dtype)
fld2 = from_random("normal", op.domain, dtype=domain_dtype)
alpha = np.random.random()
val1 = op(alpha*fld1+fld2)
val2 = alpha*op(fld1)+op(fld2)
_assert_allclose(val1, val2, atol=atol, rtol=rtol)
def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
atol=0, rtol=1e-7):
if not isinstance(op, LinearOperator):
raise TypeError('This test tests only linear operators.')
_check_linearity(op, domain_dtype, atol, rtol)
_full_implementation(op, domain_dtype, target_dtype, atol, rtol)
_full_implementation(op.adjoint, target_dtype, domain_dtype, atol, rtol)
_full_implementation(op.inverse, target_dtype, domain_dtype, atol, rtol)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment