# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

from itertools import combinations

import numpy as np

from .domain_tuple import DomainTuple
from .field import Field
from .linearization import Linearization
from .multi_domain import MultiDomain
from .multi_field import MultiField
from .operators.adder import Adder
from .operators.endomorphic_operator import EndomorphicOperator
from .operators.energy_operators import EnergyOperator
from .operators.linear_operator import LinearOperator
from .operators.operator import Operator
from .operators.scaling_operator import ScalingOperator
from .probing import StatCalculator
from .sugar import from_random, full, is_fieldlike, is_operator
from .utilities import myassert

__all__ = ["check_linear_operator", "check_operator", "assert_allclose",
           "minisanity"]


def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64,
                          atol=1e-12, rtol=1e-12, only_r_linear=False):
    """Checks an operator for algebraic consistency of its capabilities.

    Checks whether times(), adjoint_times(), inverse_times() and
    adjoint_inverse_times() (if in capability list) is implemented
    consistently. Additionally, it checks whether the operator is linear,
    pure (repeated application gives identical results), respects its
    domain/target, and, for endomorphic operators, whether a square root
    provided by `get_sqrt()` is consistent with the operator.

    Parameters
    ----------
    op : LinearOperator
        Operator which shall be checked.
    domain_dtype : dtype
        The data type of the random vectors in the operator's domain. Default
        is `np.float64`.
    target_dtype : dtype
        The data type of the random vectors in the operator's target. Default
        is `np.float64`.
    atol : float
        Absolute tolerance for the check. Default: 1e-12.
    rtol : float
        Relative tolerance for the check. Default: 1e-12.
        Both tolerances are combined as in `numpy.testing.assert_allclose`,
        i.e. a comparison passes if
        ``abs(actual - desired) <= atol + rtol * abs(desired)``.
    only_r_linear : bool
        Set to True if the operator is only R-linear, not C-linear.
        This will relax the adjointness test accordingly (only the real parts
        of the scalar products are compared).

    Raises
    ------
    TypeError
        If `op` is not a `LinearOperator`.
    """
    if not isinstance(op, LinearOperator):
        raise TypeError('This test tests only linear operators.')
    _domain_check_linear(op, domain_dtype)
    _domain_check_linear(op.adjoint, target_dtype)
    _domain_check_linear(op.inverse, target_dtype)
    _domain_check_linear(op.adjoint.inverse, domain_dtype)
    _purity_check(op, from_random(op.domain, dtype=domain_dtype))
    _purity_check(op.adjoint.inverse, from_random(op.domain, dtype=domain_dtype))
    _purity_check(op.adjoint, from_random(op.target, dtype=target_dtype))
    _purity_check(op.inverse, from_random(op.target, dtype=target_dtype))
    _check_linearity(op, domain_dtype, atol, rtol)
    _check_linearity(op.adjoint, target_dtype, atol, rtol)
    _check_linearity(op.inverse, target_dtype, atol, rtol)
    _check_linearity(op.adjoint.inverse, domain_dtype, atol, rtol)
    _full_implementation(op, domain_dtype, target_dtype, atol, rtol,
                         only_r_linear)
    _full_implementation(op.adjoint, target_dtype, domain_dtype, atol, rtol,
                         only_r_linear)
    _full_implementation(op.inverse, target_dtype, domain_dtype, atol, rtol,
                         only_r_linear)
    _full_implementation(op.adjoint.inverse, domain_dtype, target_dtype, atol,
                         rtol, only_r_linear)
    _check_sqrt(op, domain_dtype)
    _check_sqrt(op.adjoint, target_dtype)
    _check_sqrt(op.inverse, target_dtype)
    _check_sqrt(op.adjoint.inverse, domain_dtype)


def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
                   only_r_differentiable=True, metric_sampling=True):
    """Performs various checks of the implementation of linear and nonlinear
    operators.

    Computes the Jacobian with finite differences and compares it to the
    implemented Jacobian.

    Parameters
    ----------
    op : Operator
        Operator which shall be checked.
    loc : Field or MultiField
        A Field or MultiField instance which has the same domain as op.
        The location at which the gradient is checked.
    tol : float
        Tolerance for the check. The finite-difference comparison of the
        Jacobian uses ``sqrt(tol)``. Default: 1e-12.
    ntries : int
        Number of finite-difference checks. Each check starts at the location
        reached by the previous one. Default: 100.
    perf_check : Boolean
        The performance check (counting how often the input is evaluated) is
        always run and problems are logged. If True, a detected performance
        problem additionally raises a RuntimeError. May be disabled for very
        unimportant operators. Default: True.
    only_r_differentiable : Boolean
        Jacobians of C-differentiable operators need to be C-linear.
        Default: True
    metric_sampling : Boolean
        If op is an EnergyOperator, metric_sampling determines whether the
        test shall try to sample from the metric or not. Default: True.

    Raises
    ------
    TypeError
        If `op` is not an `Operator`.
    """
    if not isinstance(op, Operator):
        raise TypeError('This test tests only (nonlinear) operators.')
    _domain_check_nonlinear(op, loc)
    _purity_check(op, loc)
    _performance_check(op, loc, bool(perf_check))
    _linearization_value_consistency(op, loc)
    _jac_vs_finite_differences(op, loc, np.sqrt(tol), ntries,
                               only_r_differentiable)
    _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
                               metric_sampling)


def assert_allclose(f1, f2, atol=0, rtol=1e-7):
    """Assert that two Fields (or MultiFields, key by key) are close.

    Uses `numpy.testing.assert_allclose` on the underlying arrays. For
    MultiFields, every key of `f1` is compared with the same key of `f2`.
    """
    if isinstance(f1, Field):
        return np.testing.assert_allclose(f1.val, f2.val, atol=atol, rtol=rtol)
    for key, val in f1.items():
        assert_allclose(val, f2[key], atol=atol, rtol=rtol)


def assert_equal(f1, f2):
    """Assert that two Fields (or MultiFields, key by key) are equal.

    Uses `numpy.testing.assert_equal` on the underlying arrays. For
    MultiFields, every key of `f1` is compared with the same key of `f2`.
    """
    if isinstance(f1, Field):
        return np.testing.assert_equal(f1.val, f2.val)
    for key, val in f1.items():
        assert_equal(val, f2[key])


def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol,
                            only_r_linear):
    """Check that the scalar products ``f1 . op^H(f2)`` and ``op(f1) . f2``
    agree for random `f1` (domain) and `f2` (target).

    Skipped if `op` does not support both TIMES and ADJOINT_TIMES. With
    `only_r_linear`, only the real parts are compared.
    """
    needed_cap = op.TIMES | op.ADJOINT_TIMES
    if (op.capability & needed_cap) != needed_cap:
        return
    f1 = from_random(op.domain, "normal", dtype=domain_dtype)
    f2 = from_random(op.target, "normal", dtype=target_dtype)
    res1 = f1.s_vdot(op.adjoint_times(f2))
    res2 = op.times(f1).s_vdot(f2)
    if only_r_linear:
        res1, res2 = res1.real, res2.real
    np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)


def _inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
    """Check ``op(op^-1(x)) == x`` on the target and ``op^-1(op(x)) == x`` on
    the domain for random `x`.

    Skipped if `op` does not support both TIMES and INVERSE_TIMES.
    """
    needed_cap = op.TIMES | op.INVERSE_TIMES
    if (op.capability & needed_cap) != needed_cap:
        return
    foo = from_random(op.target, "normal", dtype=target_dtype)
    res = op(op.inverse_times(foo))
    assert_allclose(res, foo, atol=atol, rtol=rtol)

    foo = from_random(op.domain, "normal", dtype=domain_dtype)
    res = op.inverse_times(op(foo))
    assert_allclose(res, foo, atol=atol, rtol=rtol)


def _full_implementation(op, domain_dtype, target_dtype, atol, rtol,
                         only_r_linear):
    """Run the adjointness and inverse consistency checks for `op`."""
    _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol,
                            only_r_linear)
    _inverse_implementation(op, domain_dtype, target_dtype, atol, rtol)


def _check_linearity(op, domain_dtype, atol, rtol):
    """Check ``op(alpha*f1 + f2) == alpha*op(f1) + op(f2)`` for random
    `f1`, `f2` and ``alpha = 0.42``. Skipped if `op` lacks TIMES."""
    needed_cap = op.TIMES
    if (op.capability & needed_cap) != needed_cap:
        return
    fld1 = from_random(op.domain, "normal", dtype=domain_dtype)
    fld2 = from_random(op.domain, "normal", dtype=domain_dtype)
    alpha = 0.42
    val1 = op(alpha*fld1+fld2)
    val2 = alpha*op(fld1)+op(fld2)
    assert_allclose(val1, val2, atol=atol, rtol=rtol)


def _domain_check_linear(op, domain_dtype=None, inp=None):
    """Check that `op` maps an input on `op.domain` to an output on
    `op.target` (identity comparison of the domain objects).

    If `domain_dtype` is given, a random input of that dtype is drawn;
    otherwise `inp` is used.

    Raises
    ------
    ValueError
        If neither `domain_dtype` nor `inp` is given (and `op` supports
        TIMES).
    """
    _domain_check(op)
    needed_cap = op.TIMES
    if (op.capability & needed_cap) != needed_cap:
        return
    if domain_dtype is not None:
        inp = from_random(op.domain, "normal", dtype=domain_dtype)
    elif inp is None:
        raise ValueError('Need to specify either dtype or inp')
    myassert(inp.domain is op.domain)
    myassert(op(inp).domain is op.target)


def _check_sqrt(op, domain_dtype):
    """Check `get_sqrt()` of `op`.

    Non-endomorphic operators must not provide a working `get_sqrt()`. For
    endomorphic operators whose `get_sqrt()` succeeds, check
    ``op(x) == sqrt^H(sqrt(x))`` for a random `x`.
    """
    if not isinstance(op, EndomorphicOperator):
        try:
            op.get_sqrt()
            raise RuntimeError("Operator implements get_sqrt() although it is not an endomorphic operator.")
        except AttributeError:
            return
    try:
        sqop = op.get_sqrt()
    except (NotImplementedError, ValueError):
        # No square root available for this operator: nothing to check.
        return
    fld = from_random(op.domain, dtype=domain_dtype)
    a = op(fld)
    b = (sqop.adjoint @ sqop)(fld)
    return assert_allclose(a, b, rtol=1e-15)


def _domain_check_nonlinear(op, loc):
    """Check the domain/target bookkeeping of `op` applied to a
    Linearization at `loc`, with and without metric."""
    _domain_check(op)
    myassert(isinstance(loc, (Field, MultiField)))
    myassert(loc.domain is op.domain)
    for wm in [False, True]:
        lin = Linearization.make_var(loc, wm)
        reslin = op(lin)
        myassert(lin.domain is op.domain)
        myassert(lin.target is op.domain)
        myassert(lin.val.domain is lin.domain)
        myassert(reslin.domain is op.domain)
        myassert(reslin.target is op.target)
        myassert(reslin.val.domain is reslin.target)
        myassert(reslin.jac.domain is reslin.domain)
        myassert(reslin.jac.target is reslin.target)
        myassert(lin.want_metric == reslin.want_metric)
        _domain_check_linear(reslin.jac, inp=loc)
        _domain_check_linear(reslin.jac.adjoint, inp=reslin.jac(loc))
        if reslin.metric is not None:
            myassert(reslin.metric.domain is reslin.metric.target)
            myassert(reslin.metric.domain is op.domain)


def _domain_check(op):
    """Raise TypeError unless domain and target of `op` are DomainTuple or
    MultiDomain instances."""
    for dd in [op.domain, op.target]:
        if not isinstance(dd, (DomainTuple, MultiDomain)):
            # Adjacent literals (no comma) so the message is a single string.
            raise TypeError(
                'The domain and the target of an operator need to '
                'be instances of either DomainTuple or MultiDomain.')


def _performance_check(op, pos, raise_on_fail):
    """Check that `op` does not evaluate its input more often than needed.

    An identity operator that counts its applications is placed in front of
    `op`; the count is then checked after plain evaluation, linearized
    evaluation, Jacobian, adjoint Jacobian and (if present) metric
    application. Problems are logged; with `raise_on_fail` a RuntimeError is
    raised as well.
    """
    class CountingOp(LinearOperator):
        def __init__(self, domain):
            from .sugar import makeDomain
            self._domain = self._target = makeDomain(domain)
            self._capability = self.TIMES | self.ADJOINT_TIMES
            self._count = 0

        def apply(self, x, mode):
            self._count += 1
            return x

        @property
        def count(self):
            return self._count

    for wm in [False, True]:
        cop = CountingOp(op.domain)
        myop = op @ cop
        myop(pos)
        cond = [cop.count != 1]
        lin = myop(Linearization.make_var(pos, wm))
        cond.append(cop.count != 2)
        lin.jac(pos)
        cond.append(cop.count != 3)
        lin.jac.adjoint(lin.val)
        cond.append(cop.count != 4)
        if lin.metric is not None:
            lin.metric(pos)
            # The metric applies the counting operator twice (forward and
            # adjoint), hence the jump from 4 to 6.
            cond.append(cop.count != 6)
        if any(cond):
            s = 'The operator has a performance problem (want_metric={}).'.format(wm)
            from .logger import logger
            logger.error(s)
            logger.info(cond)
            if raise_on_fail:
                raise RuntimeError(s)


def _purity_check(op, pos):
    """Check that applying `op` twice to `pos` gives identical results.
    Skipped for linear operators without TIMES capability."""
    if isinstance(op, LinearOperator) and (op.capability & op.TIMES) != op.TIMES:
        return
    res0 = op(pos)
    res1 = op(pos)
    assert_equal(res0, res1)


def _get_acceptable_location(op, loc, lin):
    """Find a nearby location for finite differencing.

    A random direction is scaled so that the Jacobian predicts a change of
    about 1e-5 times the norm of ``op(loc)``. The step is halved (at most 50
    times) until ``op`` at the new location has a finite total sum with
    absolute value below 1e20.

    Returns
    -------
    tuple
        The new location and the Linearization of `op` there.

    Raises
    ------
    ValueError
        If ``op(loc)`` is not finite or no acceptable step is found.
    """
    if not np.isfinite(lin.val.s_sum()):
        raise ValueError('Initial value must be finite')
    direction = from_random(loc.domain, dtype=loc.dtype)
    dirder = lin.jac(direction)
    if dirder.norm() == 0:
        direction = direction * (lin.val.norm() * 1e-5)
    else:
        direction = direction * (lin.val.norm() * 1e-5 / dirder.norm())
    # Find a step length that leads to a "reasonable" location
    for _ in range(50):
        try:
            loc2 = loc + direction
            lin2 = op(Linearization.make_var(loc2, lin.want_metric))
            if np.isfinite(lin2.val.s_sum()) and abs(lin2.val.s_sum()) < 1e20:
                break
        except FloatingPointError:
            pass
        direction = direction * 0.5
    else:
        raise ValueError("could not find a reasonable initial step")
    return loc2, lin2


def _linearization_value_consistency(op, loc):
    """Check that ``op(loc)`` equals the value of ``op`` applied to a
    Linearization at `loc`, with and without metric."""
    for wm in [False, True]:
        lin = Linearization.make_var(loc, wm)
        fld0 = op(loc)
        fld1 = op(lin).val
        assert_allclose(fld0, fld1, 0, 1e-7)


def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
                               metric_sampling):
    """Check `simplify_for_constant_input` for subsets of constant keys.

    Only applies to operators on a MultiDomain. For every tested subset of
    constant keys, the simplified operator must reproduce the value of `op`,
    its adjoint Jacobian must agree with that of `op` on the variable keys,
    and the adjoint Jacobian of `op` must vanish on the constant keys. For
    EnergyOperators (with `metric_sampling`), a metric sample is drawn.

    With more than 4 keys only single constant keys are tested; otherwise all
    proper non-empty subsets are tested.
    """
    if isinstance(op.domain, DomainTuple):
        return
    keys = op.domain.keys()
    combis = []
    if len(keys) > 4:
        from .logger import logger
        logger.warning('Operator domain has more than 4 keys.')
        logger.warning('Check derivatives only with one constant key at a time.')
        combis = [[kk] for kk in keys]
    else:
        for ll in range(1, len(keys)):
            combis.extend(combinations(keys, ll))
    for cstkeys in combis:
        varkeys = set(keys) - set(cstkeys)
        cstloc = loc.extract_by_keys(cstkeys)
        varloc = loc.extract_by_keys(varkeys)

        val0 = op(loc)
        _, op0 = op.simplify_for_constant_input(cstloc)
        myassert(op0.domain is varloc.domain)
        val1 = op0(varloc)
        assert_equal(val0, val1)

        lin = Linearization.make_partial_var(loc, cstkeys, want_metric=True)
        lin0 = Linearization.make_var(varloc, want_metric=True)
        oplin0 = op0(lin0)
        oplin = op(lin)

        myassert(oplin.jac.target is oplin0.jac.target)

        rndinp = from_random(oplin.jac.target, dtype=oplin.val.dtype)
        assert_allclose(oplin.jac.adjoint(rndinp).extract(varloc.domain),
                        oplin0.jac.adjoint(rndinp), 1e-13, 1e-13)
        foo = oplin.jac.adjoint(rndinp).extract(cstloc.domain)
        assert_equal(foo, 0*foo)

        if isinstance(op, EnergyOperator) and metric_sampling:
            oplin.metric.draw_sample()

        # NOTE: a finite-difference check of `op0` at `varloc` (calling
        # `_jac_vs_finite_differences` with ``np.sqrt(tol)``) is currently
        # disabled; this is why `tol`, `ntries` and `only_r_differentiable`
        # are not used here.


def _jac_vs_finite_differences(op, loc, tol, ntries, only_r_differentiable):
    """Compare the Jacobian of `op` with central finite differences.

    For `ntries` consecutive locations, the difference of `op` between the
    two ends of a step is compared elementwise with the Jacobian at the
    midpoint applied to the step. The step is halved (at most 50 times) until
    they agree within ``tol * |dirder| / sqrt(size)``. The Jacobian at the
    accepted midpoint is then checked with `check_linear_operator`.

    Raises
    ------
    ValueError
        If no step length leads to agreement.
    """
    for _ in range(ntries):
        lin = op(Linearization.make_var(loc))
        loc2, lin2 = _get_acceptable_location(op, loc, lin)
        direction = loc2 - loc
        locnext = loc2
        hist = []
        for _ in range(50):
            locmid = loc + 0.5 * direction
            linmid = op(Linearization.make_var(locmid))
            dirder = linmid.jac(direction)
            numgrad = (lin2.val - lin.val)
            xtol = tol * dirder.norm() / np.sqrt(dirder.size)
            hist.append((numgrad - dirder).norm())
            if (abs(numgrad - dirder) <= xtol).s_all():
                break
            direction = direction * 0.5
            loc2, lin2 = locmid, linmid
        else:
            raise ValueError(
                "gradient and value seem inconsistent "
                f"(residual norms while halving the step: {hist})")
        loc = locnext
        check_linear_operator(linmid.jac, domain_dtype=loc.dtype,
                              target_dtype=dirder.dtype,
                              only_r_linear=only_r_differentiable,
                              atol=tol**2, rtol=tol**2)


def minisanity(data, metric_at_pos, modeldata_operator, mean, samples=None):
    """Log information about the current fit quality and prior compatibility.

    Log a table with fitting information for the likelihood and the prior.
    Assume that the variables in `mean.domain` are standard-normal distributed
    a priori. The table contains the reduced chi^2 value, the mean and the
    number of degrees of freedom for every key of a `MultiDomain`. If the
    domain is a `DomainTuple`, the displayed key is the empty string.

    If everything is consistent the reduced chi^2 values should be close to
    one and the mean of the data residuals close to zero. If the reduced chi^2
    value in latent space is significantly bigger than one and only one degree
    of freedom is present, the mean column gives an indication in which
    direction to change the respective hyper parameters.

    NaN entries of the normalized residuals (e.g. caused by NaNs in the output
    of `modeldata_operator` or in `data`) are ignored: they contribute neither
    to the chi^2 sum nor to the number of degrees of freedom.

    Reduced chi-square values above 2 or below 1/2 are printed in orange,
    values above 5 or below 1/5 in red.

    Parameters
    ----------
    data : Field or MultiField
        Data which is subtracted from the output of `modeldata_operator`.
    metric_at_pos : function
        Function which takes a `Field` or `MultiField` in the domain of `mean`
        and returns an endomorphic operator which applies the inverse of the
        noise covariance in the domain of `data`.
    modeldata_operator : Operator
        Operator which generates model data.
    mean : Field or MultiField
        Mean of input of `modeldata_operator`.
    samples : iterable of Field or MultiField, optional
        Residual samples around `mean`. Default: no samples.

    Raises
    ------
    TypeError
        If `modeldata_operator` is not an operator, or `data` or `mean` is not
        field-like.

    Note
    ----
    For computing the reduced chi^2 values and the normalized residuals, the
    metric at `mean` is used.
    """
    from .logger import logger

    if not (
        is_operator(modeldata_operator) and is_fieldlike(data) and is_fieldlike(mean)
    ):
        raise TypeError("minisanity expects an operator as `modeldata_operator` "
                        "and fields as `data` and `mean`")
    keylen = 18
    for dom in [data.domain, mean.domain]:
        if isinstance(dom, MultiDomain):
            keylen = max(keylen, max(map(len, dom.keys()), default=0))
    keylen = min(keylen, 42)

    # op0: normalized data residuals, op1: latent-space values.
    op0 = metric_at_pos(mean).get_sqrt() @ Adder(data, neg=True) @ modeldata_operator
    op1 = ScalingOperator(mean.domain, 1)
    if not isinstance(op0.target, MultiDomain):
        op0 = op0.ducktape_left("")
    if not isinstance(op1.target, MultiDomain):
        op1 = op1.ducktape_left("")

    s = [full(mean.domain, 0.0)] if samples is None else samples

    xop = op0, op1
    xkeys = op0.target.keys(), op1.target.keys()
    xredchisq, xscmean, xndof = 2*[None], 2*[None], 2*[None]
    for aa in [0, 1]:
        xredchisq[aa] = {kk: StatCalculator() for kk in xkeys[aa]}
        xscmean[aa] = {kk: StatCalculator() for kk in xkeys[aa]}
        xndof[aa] = {}
    for ss in s:
        for aa in [0, 1]:
            rr = xop[aa].force(mean.unite(ss))
            for kk in xkeys[aa]:
                vv = rr[kk].val
                # The chi^2 sum skips NaNs, so it has to be normalized by the
                # number of non-NaN entries (= reported degrees of freedom).
                ndof = vv.size - np.sum(np.isnan(vv))
                redchisq = np.nansum(abs(vv) ** 2) / ndof if ndof > 0 else np.nan
                xredchisq[aa][kk].add(redchisq)
                xscmean[aa][kk].add(np.nanmean(vv))
                xndof[aa][kk] = ndof

    s0 = _tableentries(xredchisq[0], xscmean[0], xndof[0], keylen)
    s1 = _tableentries(xredchisq[1], xscmean[1], xndof[1], keylen)

    f = logger.info
    # Total width: key column (keylen + 2) + 11 + 14 + 11.
    n = 38 + keylen
    f(n * "=")
    f(
        (keylen + 2) * " "
        + "{:>11}".format("reduced χ²")
        + "{:>14}".format("mean")
        + "{:>11}".format("# dof")
    )
    f(n * "-")
    f("Data residuals\n" + s0)
    f("Latent space\n" + s1)
    f(n * "=")


class _bcolors:
    """ANSI escape sequences used to highlight table entries."""
    WARNING = "\033[33m"
    FAIL = "\033[31m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"


def _tableentries(redchisq, scmean, ndof, keylen):
    """Format one table row per key of `redchisq` for `minisanity`.

    Each row contains the key (padded, or truncated with "…", to `keylen`
    characters after two leading spaces), the reduced chi^2 mean (± standard
    deviation when available; colored when far from one), the mean (± standard
    deviation when available) and the number of degrees of freedom. Rows are
    joined with newlines, without a trailing newline.
    """
    lines = []
    for kk in redchisq.keys():
        # Two leading spaces so the key column is keylen + 2 wide, matching
        # the table header in `minisanity`.
        if len(kk) > keylen:
            line = "  " + kk[: keylen - 1] + "…"
        else:
            line = "  " + kk.ljust(keylen)

        chisq_mean = redchisq[kk].mean
        foo = f"{chisq_mean:.1f}"
        try:
            foo += f" ± {np.sqrt(redchisq[kk].var):.1f}"
        except RuntimeError:
            # No variance available (e.g. only a single sample).
            pass
        if chisq_mean > 5 or chisq_mean < 1/5:
            line += _bcolors.FAIL + _bcolors.BOLD + f"{foo:>11}" + _bcolors.ENDC
        elif chisq_mean > 2 or chisq_mean < 1/2:
            line += _bcolors.WARNING + _bcolors.BOLD + f"{foo:>11}" + _bcolors.ENDC
        else:
            line += f"{foo:>11}"

        foo = f"{scmean[kk].mean:.1f}"
        try:
            foo += f" ± {np.sqrt(scmean[kk].var):.1f}"
        except RuntimeError:
            pass
        line += f"{foo:>14}"
        line += f"{ndof[kk]:>11}"
        lines.append(line)
    return "\n".join(lines)