Skip to content
Snippets Groups Projects
Commit 11e62c4e authored by Philipp Arras's avatar Philipp Arras
Browse files

Add performance check to jacobian consistency

parent 128309f6
No related branches found
No related tags found
1 merge request!417Performance pa
......@@ -123,6 +123,37 @@ def _domain_check(op):
'be instances of either DomainTuple or MultiDomain.')
def _performance_check(op, pos, raise_on_fail):
class CountingOp(LinearOperator):
def __init__(self, domain):
from .sugar import makeDomain
self._domain = self._target = makeDomain(domain)
self._capability = self.TIMES | self.ADJOINT_TIMES
self._count = 0
def apply(self, x, mode):
self._count += 1
return x
@property
def count(self):
return self._count
cop = CountingOp(op.domain)
op = op @ cop
lin = op(Linearization.make_var(pos))
cond = [cop.count != 1]
lin.jac(pos)
cond.append(cop.count != 2)
lin.jac.adjoint(lin.val)
cond.append(cop.count != 3)
if any(cond):
s = 'The operator has a performance problem.'
if raise_on_fail:
raise RuntimeError(s)
from .logger import logger
logger.warn(s)
def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
atol=0, rtol=1e-7, only_r_linear=False):
"""
......@@ -219,6 +250,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100):
"""
_domain_check(op)
_actual_domain_check_nonlinear(op, loc)
_performance_check(op, loc, True) # FIXME Make the default False
for _ in range(ntries):
lin = op(Linearization.make_var(loc))
loc2, lin2 = _get_acceptable_location(op, loc, lin)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment