Commit 11e62c4e authored by Philipp Arras's avatar Philipp Arras
Browse files

Add performance check to jacobian consistency

parent 128309f6
......@@ -123,6 +123,37 @@ def _domain_check(op):
'be instances of either DomainTuple or MultiDomain.')
def _performance_check(op, pos, raise_on_fail):
class CountingOp(LinearOperator):
def __init__(self, domain):
from .sugar import makeDomain
self._domain = self._target = makeDomain(domain)
self._capability = self.TIMES | self.ADJOINT_TIMES
self._count = 0
def apply(self, x, mode):
self._count += 1
return x
@property
def count(self):
return self._count
cop = CountingOp(op.domain)
op = op @ cop
lin = op(Linearization.make_var(pos))
cond = [cop.count != 1]
lin.jac(pos)
cond.append(cop.count != 2)
lin.jac.adjoint(lin.val)
cond.append(cop.count != 3)
if any(cond):
s = 'The operator has a performance problem.'
if raise_on_fail:
raise RuntimeError(s)
from .logger import logger
logger.warn(s)
def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
atol=0, rtol=1e-7, only_r_linear=False):
"""
......@@ -219,6 +250,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100):
"""
_domain_check(op)
_actual_domain_check_nonlinear(op, loc)
_performance_check(op, loc, True) # FIXME Make the default False
for _ in range(ntries):
lin = op(Linearization.make_var(loc))
loc2, lin2 = _get_acceptable_location(op, loc, lin)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment