Skip to content
Snippets Groups Projects
Select Git revision
  • 6fcfd391366f49e51a8505bc536ef2e723cd4790
  • main default protected
  • wf_ve
  • sparse_mg
  • gpu_tests
  • mpi_samplelist_fix
  • pytorch_operator
  • qpo_model_rebased
  • native_extension
  • joint_re_cl_tests
  • re_fewer_tests
  • perf_tweaks
  • NIFTy_8 protected
  • fix_nonlinearity_gradients
  • cupy_backend
  • nifty
  • nifty8_philipps_unmerged_patches
  • nifty_jr
  • frequency_model
  • 423-minisanity-re-improve-likelihood-readability
  • 420-tracerboolconversion-error-in-lognormal_moments-py
  • 9.1.0 protected
  • 9.0.0 protected
  • v8.5.7 protected
  • v8.5.6 protected
  • v8.5.5 protected
  • v8.5.4 protected
  • v8.5.3 protected
  • v8.5.2 protected
  • v8.5.1 protected
  • v8.5 protected
  • v8.4 protected
  • v8.3 protected
  • v8.2 protected
  • v8.1 protected
  • v8.0 protected
  • v7.5 protected
  • v7.4 protected
  • v7.3 protected
  • v7.2 protected
  • v7.1 protected
41 results

extra.py

Blame
  • pfrank's avatar
    Philipp Frank authored
    4517b8e7
    History
    extra.py 20.12 KiB
    # This program is free software: you can redistribute it and/or modify
    # it under the terms of the GNU General Public License as published by
    # the Free Software Foundation, either version 3 of the License, or
    # (at your option) any later version.
    #
    # This program is distributed in the hope that it will be useful,
    # but WITHOUT ANY WARRANTY; without even the implied warranty of
    # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    # GNU General Public License for more details.
    #
    # You should have received a copy of the GNU General Public License
    # along with this program.  If not, see <http://www.gnu.org/licenses/>.
    #
    # Copyright(C) 2013-2021 Max-Planck-Society
    #
    # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
    
    from itertools import combinations
    
    import numpy as np
    
    from .domain_tuple import DomainTuple
    from .field import Field
    from .linearization import Linearization
    from .multi_domain import MultiDomain
    from .multi_field import MultiField
    from .operators.adder import Adder
    from .operators.endomorphic_operator import EndomorphicOperator
    from .operators.energy_operators import EnergyOperator
    from .operators.linear_operator import LinearOperator
    from .operators.operator import Operator
    from .operators.scaling_operator import ScalingOperator
    from .probing import StatCalculator
    from .sugar import from_random, full, is_fieldlike, is_operator
    from .utilities import myassert
    
    # Names exported by `from ... import *`; `assert_equal` is defined below
    # but intentionally not listed here.
    __all__ = ["check_linear_operator", "check_operator", "assert_allclose", "minisanity"]
    
    
    def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64,
                              atol=1e-12, rtol=1e-12, only_r_linear=False):
        """Checks an operator for algebraic consistency of its capabilities.

        Checks whether times(), adjoint_times(), inverse_times() and
        adjoint_inverse_times() (if in capability list) are implemented
        consistently. Additionally, it checks whether the operator is linear,
        whether repeated application gives identical results, and the
        consistency of `get_sqrt()` (see `_check_sqrt`).

        All checks are run on `op`, `op.adjoint`, `op.inverse` and
        `op.adjoint.inverse`. Individual checks are skipped when the operator
        lacks the capability they need.

        Parameters
        ----------
        op : LinearOperator
            Operator which shall be checked.
        domain_dtype : dtype
            The data type of the random vectors in the operator's domain. Default
            is `np.float64`.
        target_dtype : dtype
            The data type of the random vectors in the operator's target. Default
            is `np.float64`.
        atol : float
            Absolute tolerance for the check. Values are compared as in
            `numpy.testing.assert_allclose`, i.e. the check passes if
            ``|actual - desired| <= atol + rtol * |desired|``.
            Default: 1e-12.
        rtol : float
            Relative tolerance for the check (see `atol`).
            Default: 1e-12.
        only_r_linear: bool
            set to True if the operator is only R-linear, not C-linear.
            This will relax the adjointness test accordingly.

        Raises
        ------
        TypeError
            If `op` is not a LinearOperator.
        """
        if not isinstance(op, LinearOperator):
            raise TypeError('This test tests only linear operators.')
        # Domain/target types and the domains of input and output.
        _domain_check_linear(op, domain_dtype)
        _domain_check_linear(op.adjoint, target_dtype)
        _domain_check_linear(op.inverse, target_dtype)
        _domain_check_linear(op.adjoint.inverse, domain_dtype)
        # Applying an operator twice to the same input must give equal results.
        _purity_check(op, from_random(op.domain, dtype=domain_dtype))
        _purity_check(op.adjoint.inverse, from_random(op.domain, dtype=domain_dtype))
        _purity_check(op.adjoint, from_random(op.target, dtype=target_dtype))
        _purity_check(op.inverse, from_random(op.target, dtype=target_dtype))
        # op(a*x + y) == a*op(x) + op(y)
        _check_linearity(op, domain_dtype, atol, rtol)
        _check_linearity(op.adjoint, target_dtype, atol, rtol)
        _check_linearity(op.inverse, target_dtype, atol, rtol)
        _check_linearity(op.adjoint.inverse, domain_dtype, atol, rtol)
        # Adjoint and inverse consistency.
        _full_implementation(op, domain_dtype, target_dtype, atol, rtol,
                             only_r_linear)
        _full_implementation(op.adjoint, target_dtype, domain_dtype, atol, rtol,
                             only_r_linear)
        _full_implementation(op.inverse, target_dtype, domain_dtype, atol, rtol,
                             only_r_linear)
        _full_implementation(op.adjoint.inverse, domain_dtype, target_dtype, atol,
                             rtol, only_r_linear)
        # get_sqrt() consistency.
        _check_sqrt(op, domain_dtype)
        _check_sqrt(op.adjoint, target_dtype)
        _check_sqrt(op.inverse, target_dtype)
        _check_sqrt(op.adjoint.inverse, domain_dtype)
    
    
    def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
                       only_r_differentiable=True, metric_sampling=True):
        """Performs various checks of the implementation of linear and nonlinear
        operators.

        Checks domains and purity (repeated application gives identical
        results). Counts how often the operator is applied internally, and
        checks consistency between plain and linearized application. Computes
        the Jacobian with finite differences and compares it to the implemented
        Jacobian. For operators on a `MultiDomain`, additionally checks the
        behavior when some input keys are held constant.

        Parameters
        ----------
        op : Operator
            Operator which shall be checked.
        loc : Field or MultiField
            A Field or MultiField instance which has the same domain
            as op. The location at which the gradient is checked.
        tol : float
            Tolerance for the check. The finite-difference comparison of the
            Jacobian is done with tolerance `sqrt(tol)`.
        ntries : int
            Number of locations, starting at `loc`, at which the Jacobian is
            compared with finite differences. Default: 100.
        perf_check : Boolean
            The performance check (counting internal applications of `op`) is
            always run, and detected problems are logged. If True, a detected
            problem additionally raises a RuntimeError. May be set to False for
            very unimportant operators.
        only_r_differentiable : Boolean
            Jacobians of C-differentiable operators need to be C-linear.
            Default: True
        metric_sampling: Boolean
            If op is an EnergyOperator, metric_sampling determines whether the
            test shall try to sample from the metric or not.

        Raises
        ------
        TypeError
            If `op` is not an Operator.
        """
        if not isinstance(op, Operator):
            raise TypeError('This test tests only (nonlinear) operators.')
        _domain_check_nonlinear(op, loc)
        _purity_check(op, loc)
        _performance_check(op, loc, bool(perf_check))
        _linearization_value_consistency(op, loc)
        # Finite differences are less accurate than `tol`, hence the sqrt.
        _jac_vs_finite_differences(op, loc, np.sqrt(tol), ntries,
                                   only_r_differentiable)
        _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
                                   metric_sampling)
    
    
    def assert_allclose(f1, f2, atol=0, rtol=1e-7):
        """Assert that two fields agree elementwise within the given tolerances.

        Parameters
        ----------
        f1, f2 : Field or MultiField
            Fields to compare. If `f1` is not a Field it is treated as a
            MultiField; then both must live on the very same domain object and
            are compared key by key.
        atol, rtol : float
            Tolerances passed on to `numpy.testing.assert_allclose`.

        Raises
        ------
        AssertionError
            If the values differ beyond tolerance or the multi-field domains are
            not identical.
        """
        if isinstance(f1, Field):
            return np.testing.assert_allclose(f1.val, f2.val, atol=atol, rtol=rtol)
        if f1.domain is not f2.domain:
            raise AssertionError("The domains of the two fields are not identical.")
        for key, val in f1.items():
            assert_allclose(val, f2[key], atol=atol, rtol=rtol)
    
    
    def assert_equal(f1, f2):
        """Assert that two fields are exactly equal.

        Parameters
        ----------
        f1, f2 : Field or MultiField
            Fields to compare. If `f1` is not a Field it is treated as a
            MultiField; then both must live on the very same domain object and
            are compared key by key.

        Raises
        ------
        AssertionError
            If any value differs or the multi-field domains are not identical.
        """
        if isinstance(f1, Field):
            return np.testing.assert_equal(f1.val, f2.val)
        if f1.domain is not f2.domain:
            raise AssertionError("The domains of the two fields are not identical.")
        for key, val in f1.items():
            assert_equal(val, f2[key])
    
    
    def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol,
                                only_r_linear):
        """Verify <x, A^dagger y> == <A x, y> for random normal x and y.

        Skipped unless `op` supports both TIMES and ADJOINT_TIMES. With
        `only_r_linear`, only the real parts of the two scalar products are
        compared.
        """
        required = op.TIMES | op.ADJOINT_TIMES
        if (op.capability & required) == required:
            x = from_random(op.domain, "normal", dtype=domain_dtype)
            y = from_random(op.target, "normal", dtype=target_dtype)
            lhs = x.s_vdot(op.adjoint_times(y))
            rhs = op.times(x).s_vdot(y)
            if only_r_linear:
                lhs = lhs.real
                rhs = rhs.real
            np.testing.assert_allclose(lhs, rhs, atol=atol, rtol=rtol)
    
    
    def _inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
        """Verify that `inverse_times` undoes `times` from both sides.

        Skipped unless `op` supports both TIMES and INVERSE_TIMES.
        """
        required = op.TIMES | op.INVERSE_TIMES
        if (op.capability & required) != required:
            return
        # op(op^{-1}(y)) == y for a random y in the target.
        y = from_random(op.target, "normal", dtype=target_dtype)
        assert_allclose(op(op.inverse_times(y)), y, atol=atol, rtol=rtol)
        # op^{-1}(op(x)) == x for a random x in the domain.
        x = from_random(op.domain, "normal", dtype=domain_dtype)
        assert_allclose(op.inverse_times(op(x)), x, atol=atol, rtol=rtol)
    
    
    def _full_implementation(op, domain_dtype, target_dtype, atol, rtol,
                             only_r_linear):
        """Run both the adjointness and the inverse consistency check on `op`."""
        tolerances = dict(atol=atol, rtol=rtol)
        _adjoint_implementation(op, domain_dtype, target_dtype,
                                only_r_linear=only_r_linear, **tolerances)
        _inverse_implementation(op, domain_dtype, target_dtype, **tolerances)
    
    
    def _check_linearity(op, domain_dtype, atol, rtol):
        """Verify op(a*x + y) == a*op(x) + op(y) for random x, y and a = 0.42.

        Skipped if `op` does not support TIMES.
        """
        if (op.capability & op.TIMES) != op.TIMES:
            return
        x = from_random(op.domain, "normal", dtype=domain_dtype)
        y = from_random(op.domain, "normal", dtype=domain_dtype)
        scale = 0.42
        combined = op(scale*x + y)
        separate = scale*op(x) + op(y)
        assert_allclose(combined, separate, atol=atol, rtol=rtol)
    
    
    def _domain_check_linear(op, domain_dtype=None, inp=None):
        """Verify the domain and target of a linear operator.

        Always checks the types of `op.domain` and `op.target`. If `op` supports
        TIMES, it is applied to an input, and the domains of input and output
        are compared with `op.domain` and `op.target`. The input is a random
        normal field of `domain_dtype` if that is given, otherwise `inp`.

        Raises
        ------
        ValueError
            If TIMES is supported but neither `domain_dtype` nor `inp` is given.
        """
        _domain_check(op)
        if (op.capability & op.TIMES) != op.TIMES:
            return
        if domain_dtype is None:
            if inp is None:
                raise ValueError('Need to specify either dtype or inp')
        else:
            inp = from_random(op.domain, "normal", dtype=domain_dtype)
        myassert(inp.domain is op.domain)
        myassert(op(inp).domain is op.target)
    
    
    def _check_sqrt(op, domain_dtype):
        """Verify `get_sqrt()` of `op`.

        A non-endomorphic operator must not provide `get_sqrt()`; if calling it
        does not raise AttributeError, a RuntimeError is raised. For endomorphic
        operators whose `get_sqrt()` does not raise NotImplementedError or
        ValueError, check op(x) == (S^dagger S)(x) for a random x, where S is
        the returned square root.
        """
        if not isinstance(op, EndomorphicOperator):
            try:
                op.get_sqrt()
            except AttributeError:
                return
            else:
                raise RuntimeError("Operator implements get_sqrt() although it is not an endomorphic operator.")
        try:
            sqrt_op = op.get_sqrt()
        except (NotImplementedError, ValueError):
            return
        x = from_random(op.domain, dtype=domain_dtype)
        expected = op(x)
        via_sqrt = (sqrt_op.adjoint @ sqrt_op)(x)
        return assert_allclose(expected, via_sqrt, rtol=1e-15)
    
    
    def _domain_check_nonlinear(op, loc):
        """Check domain consistency of `op` applied to linearizations at `loc`.

        For `want_metric` False and True, applies `op` to a Linearization at
        `loc`. It then checks the domains and targets of the input and output
        linearizations, of the Jacobian and its adjoint, and, if present, of
        the metric.
        """
        _domain_check(op)
        myassert(isinstance(loc, (Field, MultiField)))
        myassert(loc.domain is op.domain)
        for wm in [False, True]:
            lin = Linearization.make_var(loc, wm)
            reslin = op(lin)
            # Input linearization lives on op.domain.
            myassert(lin.domain is op.domain)
            myassert(lin.target is op.domain)
            myassert(lin.val.domain is lin.domain)
            # Output linearization maps op.domain -> op.target.
            myassert(reslin.domain is op.domain)
            myassert(reslin.target is op.target)
            myassert(reslin.val.domain is reslin.target)
            myassert(reslin.jac.domain is reslin.domain)
            myassert(reslin.jac.target is reslin.target)
            myassert(lin.want_metric == reslin.want_metric)
            _domain_check_linear(reslin.jac, inp=loc)
            _domain_check_linear(reslin.jac.adjoint, inp=reslin.jac(loc))
            if reslin.metric is not None:
                myassert(reslin.metric.domain is reslin.metric.target)
                myassert(reslin.metric.domain is op.domain)
    
    
    def _domain_check(op):
        """Check that the domain and the target of `op` have valid types.

        Raises
        ------
        TypeError
            If `op.domain` or `op.target` is neither a DomainTuple nor a
            MultiDomain.
        """
        for dd in [op.domain, op.target]:
            if not isinstance(dd, (DomainTuple, MultiDomain)):
                # Adjacent literals are concatenated into ONE message string
                # (passing them as two arguments would yield a tuple message).
                raise TypeError(
                    'The domain and the target of an operator need to '
                    'be instances of either DomainTuple or MultiDomain.')
    
    
    def _performance_check(op, pos, raise_on_fail):
        """Check how often `op` internally applies its input.

        `op` is composed with an identity operator that counts its own
        applications. The expected cumulative counts are:

        - 1 after plain application;
        - 2 after building the linearization;
        - 3 after applying the Jacobian;
        - 4 after applying its adjoint;
        - 6 after applying the metric, if one is present.

        This is done for `want_metric` False and True. Any mismatch is logged
        as an error. If `raise_on_fail` is set, a RuntimeError is raised too.
        """
        class _CountingIdentity(LinearOperator):
            # Identity operator on `domain` that counts how often it is applied.
            def __init__(self, domain):
                from .sugar import makeDomain
                self._domain = self._target = makeDomain(domain)
                self._capability = self.TIMES | self.ADJOINT_TIMES
                self._count = 0

            def apply(self, x, mode):
                self._count += 1
                return x

            @property
            def count(self):
                return self._count

        for want_metric in (False, True):
            counter = _CountingIdentity(op.domain)
            probe = op @ counter
            mismatches = []

            probe(pos)
            mismatches.append(counter.count != 1)
            lin = probe(Linearization.make_var(pos, want_metric))
            mismatches.append(counter.count != 2)
            lin.jac(pos)
            mismatches.append(counter.count != 3)
            lin.jac.adjoint(lin.val)
            mismatches.append(counter.count != 4)
            if lin.metric is not None:
                lin.metric(pos)
                # Exactly two more counted applications are expected here.
                mismatches.append(counter.count != 6)

            if not any(mismatches):
                continue
            s = 'The operator has a performance problem (want_metric={}).'.format(want_metric)
            from .logger import logger
            logger.error(s)
            logger.info(mismatches)
            if raise_on_fail:
                raise RuntimeError(s)
    
    
    def _purity_check(op, pos):
        """Check that applying `op` twice to `pos` gives identical results.

        Linear operators without the TIMES capability are skipped.
        """
        lacks_times = (isinstance(op, LinearOperator)
                       and (op.capability & op.TIMES) != op.TIMES)
        if not lacks_times:
            first, second = op(pos), op(pos)
            assert_equal(first, second)
    
    
    def _get_acceptable_location(op, loc, lin):
        """Find a point near `loc` where `op` has a finite, moderate value.

        A random direction is scaled to length ``1e-5 * |op(loc)|``. If the
        Jacobian maps it to a nonzero vector, the scale is further divided by
        that vector's norm. The step is then halved up to 50 times until the
        sum of the value at ``loc + step`` is finite and below 1e20 in absolute
        value. A FloatingPointError during a trial counts as a failed trial.

        Returns
        -------
        tuple
            The accepted location and the linearization of `op` there, built
            with the same `want_metric` as `lin`.

        Raises
        ------
        ValueError
            If the value at `loc` is not finite or no acceptable step is found.
        """
        if not np.isfinite(lin.val.s_sum()):
            raise ValueError('Initial value must be finite')
        step = from_random(loc.domain, dtype=loc.dtype)
        predicted = lin.jac(step).norm()
        scale = lin.val.norm() * 1e-5
        if predicted != 0:
            scale = scale / predicted
        step = step * scale
        for _ in range(50):
            try:
                candidate = loc + step
                cand_lin = op(Linearization.make_var(candidate, lin.want_metric))
                total = cand_lin.val.s_sum()
                if np.isfinite(total) and abs(total) < 1e20:
                    return candidate, cand_lin
            except FloatingPointError:
                pass
            step = step * 0.5
        raise ValueError("could not find a reasonable initial step")
    
    
    def _linearization_value_consistency(op, loc):
        """Check that op(loc) equals the value of op applied to a Linearization.

        Done for `want_metric` False and True, with atol=0 and rtol=1e-7.
        """
        for want_metric in (False, True):
            lin = Linearization.make_var(loc, want_metric)
            plain = op(loc)
            linearized = op(lin).val
            assert_allclose(plain, linearized, atol=0, rtol=1e-7)
    
    
    def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
                                   metric_sampling):
        """Check `op` when some keys of its MultiDomain input are held constant.

        Does nothing if `op.domain` is a DomainTuple. Otherwise, some keys are
        treated as constant: every non-empty proper subset of the domain keys,
        or every single key if there are more than four keys. For each choice:

        - `op.simplify_for_constant_input` must yield an operator on the
          remaining keys that reproduces `op(loc)` exactly;
        - the adjoint Jacobian of the partial linearization must match that of
          the simplified operator on the variable keys and be zero on the
          constant keys;
        - for an EnergyOperator with `metric_sampling`, a metric sample is drawn.

        `tol`, `ntries` and `only_r_differentiable` are currently unused: the
        finite-difference check of the simplified operator is disabled.
        """
        if isinstance(op.domain, DomainTuple):
            return
        keys = op.domain.keys()
        if len(keys) > 4:
            from .logger import logger
            logger.warning('Operator domain has more than 4 keys.')
            logger.warning('Check derivatives only with one constant key at a time.')
            combis = [[kk] for kk in keys]
        else:
            # All non-empty proper subsets of the keys.
            combis = []
            for ll in range(1, len(keys)):
                combis.extend(combinations(keys, ll))
        if not combis:
            return
        allkeys = set(keys)
        # The full value does not depend on which keys are held constant.
        val0 = op(loc)
        for cstkeys in combis:
            varkeys = allkeys - set(cstkeys)
            cstloc = loc.extract_by_keys(cstkeys)
            varloc = loc.extract_by_keys(varkeys)

            # Simplifying for the constant part must not change the value.
            _, op0 = op.simplify_for_constant_input(cstloc)
            myassert(op0.domain is varloc.domain)
            val1 = op0(varloc)
            assert_equal(val0, val1)

            # Partial linearization of op vs. full linearization of op0.
            lin = Linearization.make_partial_var(loc, cstkeys, want_metric=True)
            lin0 = Linearization.make_var(varloc, want_metric=True)
            oplin0 = op0(lin0)
            oplin = op(lin)

            myassert(oplin.jac.target is oplin0.jac.target)
            rndinp = from_random(oplin.jac.target)
            assert_allclose(oplin.jac.adjoint(rndinp).extract(varloc.domain),
                            oplin0.jac.adjoint(rndinp), 1e-13, 1e-13)
            # The adjoint Jacobian must be zero on the constant keys.
            foo = oplin.jac.adjoint(rndinp).extract(cstloc.domain)
            assert_equal(foo, 0*foo)

            if isinstance(op, EnergyOperator) and metric_sampling:
                oplin.metric.draw_sample()

            # NOTE: this disabled check is the only consumer of `tol`, `ntries`
            # and `only_r_differentiable`.
            # _jac_vs_finite_differences(op0, varloc, np.sqrt(tol), ntries,
            #                            only_r_differentiable)
    
    
    def _jac_vs_finite_differences(op, loc, tol, ntries, only_r_differentiable):
        """Compare the Jacobian of `op` with finite differences.

        Repeated `ntries` times. Each try starts at the location accepted by
        the previous try (the first at `loc`):

        1. Pick a nearby end point via `_get_acceptable_location`.
        2. Compare ``op(end) - op(start)`` with the Jacobian at the midpoint
           applied to ``end - start`` (a central difference). Every entry must
           agree within ``tol * |J d| / sqrt(size)``. Otherwise the step is
           halved (the midpoint becomes the new end point), up to 50 times.
        3. Check the Jacobian at the accepted midpoint with
           `check_linear_operator`, using tolerance ``tol**2``.

        Raises
        ------
        ValueError
            If no step size satisfies the tolerance. The message contains the
            residual norm recorded at each halving step.
        """
        for _ in range(ntries):
            lin = op(Linearization.make_var(loc))
            loc2, lin2 = _get_acceptable_location(op, loc, lin)
            direction = loc2 - loc
            # The next try starts from the initially accepted end point.
            locnext = loc2
            hist = []
            for _step in range(50):
                locmid = loc + 0.5 * direction
                linmid = op(Linearization.make_var(locmid))
                dirder = linmid.jac(direction)
                numgrad = lin2.val - lin.val
                xtol = tol * dirder.norm() / np.sqrt(dirder.size)
                hist.append((numgrad - dirder).norm())
                if (abs(numgrad - dirder) <= xtol).s_all():
                    break
                # Halve the step; the midpoint becomes the new end point.
                direction = direction * 0.5
                lin2 = linmid
            else:
                raise ValueError(
                    "gradient and value seem inconsistent "
                    f"(residual norm per halving step: {hist})")
            loc = locnext
            check_linear_operator(linmid.jac, domain_dtype=loc.dtype,
                                  target_dtype=dirder.dtype,
                                  only_r_linear=only_r_differentiable,
                                  atol=tol**2, rtol=tol**2)
    
    
    def minisanity(data, metric_at_pos, modeldata_operator, mean, samples=None):
        """Log information about the current fit quality and prior compatibility.

        Log a table with fitting information for the likelihood and the prior.
        Assume that the variables in `mean.domain` are standard-normal
        distributed a priori. The table contains the reduced chi^2 value, the mean
        and the number of degrees of freedom for every key of a `MultiDomain`. If
        the domain is a `DomainTuple`, the displayed key is `<None>`.

        If everything is consistent the reduced chi^2 values should be close to one
        and the mean of the data residuals close to zero. If the reduced chi^2 value
        in latent space is significantly bigger than one and only one degree of
        freedom is present, the mean column gives an indication in which direction
        to change the respective hyper parameters.

        Ignore all NaN entries in the target of `modeldata_operator` and in `data`.
        They enter neither the chi^2 sum, nor the mean, nor the number of
        degrees of freedom. Reduced chi^2 values are printed in orange if they
        are above 2 or below 1/2, and in red if above 5 or below 1/5.

        Parameters
        ----------
        data : Field or MultiField
            Data which is subtracted from the output of `modeldata_operator`.

        metric_at_pos : function
            Function which takes a `Field` or `MultiField` in the domain of `mean`
            and returns an endomorphic operator which applies the inverse of the
            noise covariance in the domain of `data`.

        modeldata_operator : Operator
            Operator which generates model data.

        mean : Field or MultiField
            Mean of input of `modeldata_operator`.

        samples : iterable of Field or MultiField, optional
            Residual samples around `mean`. Default: no samples.

        Raises
        ------
        TypeError
            If `modeldata_operator` is not an operator, or if `data` or `mean`
            is not field-like.

        Note
        ----
        For computing the reduced chi^2 values and the normalized residuals, the
        metric at `mean` is used.

        """
        from .logger import logger
        if not (
            is_operator(modeldata_operator)
            and is_fieldlike(data)
            and is_fieldlike(mean)
        ):
            raise TypeError(
                "minisanity: `modeldata_operator` must be an operator and "
                "`data` and `mean` must be fields or multi-fields."
            )
        # Width of the key column: at least 18, at most 42 characters.
        keylen = 18
        for dom in [data.domain, mean.domain]:
            if isinstance(dom, MultiDomain):
                keylen = max([max(map(len, dom.keys())), keylen])
        keylen = min([keylen, 42])
        # op0: data residuals weighted with the square root of the metric at
        # `mean`; op1: the latent-space position itself.
        op0 = metric_at_pos(mean).get_sqrt() @ Adder(data, neg=True) @ modeldata_operator
        op1 = ScalingOperator(mean.domain, 1)
        if not isinstance(op0.target, MultiDomain):
            op0 = op0.ducktape_left("<None>")
        if not isinstance(op1.target, MultiDomain):
            op1 = op1.ducktape_left("<None>")
        s = [full(mean.domain, 0.0)] if samples is None else samples
        xop = op0, op1
        xkeys = op0.target.keys(), op1.target.keys()
        xredchisq, xscmean, xndof = 2*[None], 2*[None], 2*[None]
        for aa in [0, 1]:
            xredchisq[aa] = {kk: StatCalculator() for kk in xkeys[aa]}
            xscmean[aa] = {kk: StatCalculator() for kk in xkeys[aa]}
            xndof[aa] = {}
        for ss in s:
            for aa in [0, 1]:
                rr = xop[aa].force(mean.unite(ss))
                for kk in xkeys[aa]:
                    vals = rr[kk].val
                    # NaN entries are excluded from the chi^2 sum, so normalize
                    # by the number of non-NaN entries (the `# dof` column),
                    # not by the full size.
                    ndof = rr[kk].size - np.sum(np.isnan(vals))
                    chisq = np.nansum(abs(vals) ** 2)
                    xredchisq[aa][kk].add(chisq / ndof if ndof > 0 else np.nan)
                    xscmean[aa][kk].add(np.nanmean(vals))
                    # Overwritten per sample: the table shows the last one.
                    xndof[aa][kk] = ndof

        s0 = _tableentries(xredchisq[0], xscmean[0], xndof[0], keylen)
        s1 = _tableentries(xredchisq[1], xscmean[1], xndof[1], keylen)

        f = logger.info
        n = 38 + keylen
        f(n * "=")
        f(
            (keylen + 2) * " "
            + "{:>11}".format("reduced χ²")
            + "{:>14}".format("mean")
            + "{:>11}".format("# dof")
        )
        f(n * "-")
        f("Data residuals\n" + s0)
        f("Latent space\n" + s1)
        f(n * "=")
    
    
    class _bcolors:
        """ANSI escape sequences used to highlight entries of the minisanity table."""

        BOLD = "\033[1m"      # bold text
        WARNING = "\033[33m"  # yellow/orange foreground
        FAIL = "\033[31m"     # red foreground
        ENDC = "\033[0m"      # reset all attributes
    
    
    def _mean_pm_std(stat):
        """Format `stat.mean` as "m ± s" (one decimal), with s = sqrt(stat.var).

        If reading `stat.var` raises RuntimeError, only the mean is shown.
        """
        text = f"{stat.mean:.1f}"
        try:
            text += f" ± {np.sqrt(stat.var):.1f}"
        except RuntimeError:
            pass
        return text


    def _tableentries(redchisq, scmean, ndof, keylen):
        """Format the rows of the minisanity table.

        Parameters
        ----------
        redchisq, scmean : dict
            Map each key to a statistics object with `mean` and `var`
            attributes. Presumably these are `StatCalculator`s. Reading `var`
            may raise RuntimeError, in which case the spread is omitted.
        ndof : dict
            Number of degrees of freedom per key.
        keylen : int
            Width of the key column. Longer keys are truncated to exactly
            `keylen` characters ending in "…", so all columns stay aligned.

        Returns
        -------
        str
            One row per key of `redchisq`, joined by newlines (no trailing
            newline). The reduced chi^2 cell is colored red if its mean is
            > 5 or < 1/5, and orange if > 2 or < 1/2.
        """
        rows = []
        for kk in redchisq.keys():
            if len(kk) > keylen:
                row = "  " + kk[: keylen - 1] + "…"
            else:
                row = "  " + kk.ljust(keylen)

            chi = redchisq[kk].mean
            cell = f"{_mean_pm_std(redchisq[kk]):>11}"
            if chi > 5 or chi < 1/5:
                row += _bcolors.FAIL + _bcolors.BOLD + cell + _bcolors.ENDC
            elif chi > 2 or chi < 1/2:
                row += _bcolors.WARNING + _bcolors.BOLD + cell + _bcolors.ENDC
            else:
                row += cell

            row += f"{_mean_pm_std(scmean[kk]):>14}"
            row += f"{ndof[kk]:>11}"
            rows.append(row)
        return "\n".join(rows)