Skip to content
Snippets Groups Projects
Commit f1709dd5 authored by Philipp Arras's avatar Philipp Arras
Browse files

Even more thorough automatic tests for (linear) operators

parent ae0f9525
No related branches found
No related tags found
2 merge requests!324New gridder (again!),!318switch from numpy.fft to pypocketfft
Pipeline #48150 failed
...@@ -17,8 +17,10 @@ ...@@ -17,8 +17,10 @@
import numpy as np import numpy as np
from .domain_tuple import DomainTuple
from .field import Field from .field import Field
from .linearization import Linearization from .linearization import Linearization
from .multi_domain import MultiDomain
from .operators.linear_operator import LinearOperator from .operators.linear_operator import LinearOperator
from .sugar import from_random from .sugar import from_random
...@@ -76,6 +78,13 @@ def _check_linearity(op, domain_dtype, atol, rtol): ...@@ -76,6 +78,13 @@ def _check_linearity(op, domain_dtype, atol, rtol):
_assert_allclose(val1, val2, atol=atol, rtol=rtol) _assert_allclose(val1, val2, atol=atol, rtol=rtol)
def _domain_check(op):
for dd in [op.domain, op.target]:
if not isinstance(dd, (DomainTuple, MultiDomain)):
raise TypeError('The domain and the target of an operator need to',
'be instances of either DomainTuple or MultiDomain.')
def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64, def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
atol=0, rtol=1e-7, only_r_linear=False): atol=0, rtol=1e-7, only_r_linear=False):
""" """
...@@ -109,6 +118,7 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64, ...@@ -109,6 +118,7 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
""" """
if not isinstance(op, LinearOperator): if not isinstance(op, LinearOperator):
raise TypeError('This test tests only linear operators.') raise TypeError('This test tests only linear operators.')
_domain_check(op)
_check_linearity(op, domain_dtype, atol, rtol) _check_linearity(op, domain_dtype, atol, rtol)
_full_implementation(op, domain_dtype, target_dtype, atol, rtol, _full_implementation(op, domain_dtype, target_dtype, atol, rtol,
only_r_linear) only_r_linear)
...@@ -162,6 +172,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100): ...@@ -162,6 +172,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100):
tol : float tol : float
Tolerance for the check. Tolerance for the check.
""" """
_domain_check(op)
for _ in range(ntries): for _ in range(ntries):
lin = op(Linearization.make_var(loc)) lin = op(Linearization.make_var(loc))
loc2, lin2 = _get_acceptable_location(op, loc, lin) loc2, lin2 = _get_acceptable_location(op, loc, lin)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment