Commit 3aeba77e authored by Philipp Arras's avatar Philipp Arras

Support operators with very big domains improvements

parent 8da1276e
Pipeline #77076 passed with stages
in 13 minutes and 19 seconds
......@@ -87,8 +87,7 @@ def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64,
def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
only_r_differentiable=True, metric_sampling=True,
max_combinations=np.inf):
only_r_differentiable=True, metric_sampling=True):
"""
Performs various checks of the implementation of linear and nonlinear
operators.
......@@ -113,9 +112,6 @@ def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
metric_sampling: Boolean
If op is an EnergyOperator, metric_sampling determines whether the
test shall try to sample from the metric or not.
max_combinations : Integer
The maximum number of combinations of keys which shall be used for
checking partially constant operator and its derivative.
"""
if not isinstance(op, Operator):
raise TypeError('This test tests only linear operators.')
......@@ -125,7 +121,7 @@ def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
_jac_vs_finite_differences(op, loc, np.sqrt(tol), ntries,
only_r_differentiable)
_check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
metric_sampling, max_combinations)
metric_sampling)
def assert_allclose(f1, f2, atol, rtol):
......@@ -318,22 +314,19 @@ def _linearization_value_consistency(op, loc):
def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
metric_sampling, max_combinations=np.inf):
metric_sampling):
if isinstance(op.domain, DomainTuple):
return
keys = op.domain.keys()
combis = []
if len(keys) > 15:
if len(keys) > 4:
from .logger import logger
logger.warning('Operator domain has more than 15 keys.')
logger.warning('Operator domain has more than 4 keys.')
logger.warning('Check derivatives only with one constant key at a time.')
combis = list(keys)
combis = [[kk] for kk in keys]
else:
for ll in range(1, len(keys)):
combis.extend(list(combinations(keys, ll)))
if len(combis) > max_combinations:
combis = random.current_rng().choice(combis, int(max_combinations),
replace=False)
for cstkeys in combis:
varkeys = set(keys) - set(cstkeys)
cstloc = loc.extract_by_keys(cstkeys)
......
......@@ -108,7 +108,5 @@ def testAmplitudesInvariants(sspace, N):
assert_(op.target[-1] == fsspace)
for ampl in fa.normalized_amplitudes:
ift.extra.check_operator(ampl, 0.1*ift.from_random(ampl.domain),
ntries=10, max_combinations=3)
ift.extra.check_operator(op, 0.1*ift.from_random(op.domain),
ntries=10, max_combinations=5)
ift.extra.check_operator(ampl, 0.1*ift.from_random(ampl.domain), ntries=10)
ift.extra.check_operator(op, 0.1*ift.from_random(op.domain), ntries=10)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment