# test_model_gradients.py: value/gradient consistency tests for NIFTy5 operators.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np
import pytest

import nifty5 as ift

from .common import list2fixture

# Shorthand for the parametrize decorator used by the tests below.
pmp = pytest.mark.parametrize

# Fixture cycling through one GLSpace and two RGSpaces (1D and 2D) that all
# share the same pixel distance.
space = list2fixture(
    [ift.GLSpace(15)]
    + [ift.RGSpace(shape, distances=.789) for shape in (64, [32, 32])]
)
# Second fixture name bound to the same space list; presumably meant for
# tests that need two space arguments (not used in the visible tests).
space1 = space
# Seeds used to make the random draws in the tests reproducible.
seed = list2fixture([4, 78, 23])


def _make_linearization(type, space, seed):
    """Seed numpy, draw a sample on ``space`` and wrap it in a Linearization.

    Parameters
    ----------
    type : str
        ``"Constant"`` selects ``ift.Linearization.make_const``;
        ``"Variable"`` selects ``ift.Linearization.make_var``.
    space : domain
        Domain handed to ``ift.ScalingOperator(1., space)``, whose
        ``draw_sample()`` provides the wrapped value.
    seed : int
        Passed to ``np.random.seed`` before drawing.

    Raises
    ------
    ValueError
        If ``type`` is neither of the two accepted strings. The sample has
        already been drawn (and numpy reseeded) when this is raised.
    """
    np.random.seed(seed)
    sample = ift.ScalingOperator(1., space).draw_sample()
    if type == "Constant":
        make = ift.Linearization.make_const
    elif type == "Variable":
        make = ift.Linearization.make_var
    else:
        raise ValueError('unknown type passed')
    return make(sample)


def testBasics(space, seed):
    """Check value/gradient consistency of a plain ScalingOperator."""
    lin = _make_linearization("Variable", space, seed)
    op = ift.ScalingOperator(6., lin.target)
    ift.extra.check_value_gradient_consistency(op, lin.val)


@pmp('type1', ['Variable', 'Constant'])
@pmp('type2', ['Variable'])
def testBinary(type1, type2, space, seed):
    """Check value/gradient consistency of operators built from two inputs.

    All operators act on the MultiDomain ``{'s1': space, 's2': space}``.
    They cover:

    - a product and a sum of the two components;
    - scaling one component;
    - a scaled product, with and without ``sigmoid``;
    - an ``OuterProduct``;
    - an FFT of the product, only when ``space`` is an RGSpace.
    """
    dom1 = ift.MultiDomain.make({'s1': space})
    # The returned linearizations are not used below. The calls are kept
    # because _make_linearization reseeds numpy's global RNG, which makes
    # the random positions drawn below reproducible per ``seed``. They also
    # exercise make_const/make_var on a MultiDomain.
    _make_linearization(type1, dom1, seed)
    dom2 = ift.MultiDomain.make({'s2': space})
    _make_linearization(type2, dom2, seed)

    dom = ift.MultiDomain.union((dom1, dom2))
    select_s1 = ift.ducktape(None, dom, "s1")
    select_s2 = ift.ducktape(None, dom, "s2")
    product = select_s1*select_s2
    scaled_product = ift.ScalingOperator(2.456, space)(product)

    # (operator under test, domain the random test position is drawn from).
    # The scaled single component only needs an 's1' position.
    cases = [
        (product, dom),
        (select_s1 + select_s2, dom),
        (select_s1.scale(3.), dom1),
        (scaled_product, dom),
        (ift.sigmoid(scaled_product), dom),
    ]
    for model, domain in cases:
        pos = ift.from_random("normal", domain)
        ift.extra.check_value_gradient_consistency(model, pos, ntries=20)

    # The outer product takes the 's1' component of a random position as
    # its fixed factor and is checked at the 's2' component.
    pos = ift.from_random("normal", dom)
    model = ift.OuterProduct(pos['s1'], ift.makeDomain(space))
    ift.extra.check_value_gradient_consistency(model, pos['s2'], ntries=20)

    if isinstance(space, ift.RGSpace):
        model = ift.FFTOperator(space)(product)
        pos = ift.from_random("normal", dom)
        ift.extra.check_value_gradient_consistency(model, pos, ntries=20)


def testModelLibrary(space, seed):
    """Check gradients of the amplitude model and the correlated-field model.

    An ``AmplitudeOperator`` is built on ``space`` and checked first. It is
    then used as the amplitude of an ``ift.CorrelatedField``, which is
    checked as well.
    """
    # Positional hyperparameters passed straight to ift.AmplitudeOperator.
    Npixdof, ceps_a, ceps_k, sm, sv, im, iv = 4, 0.5, 2., 3., 1.5, 1.75, 1.3
    np.random.seed(seed)
    model = ift.AmplitudeOperator(space, Npixdof, ceps_a, ceps_k, sm, sv, im,
                                  iv)
    S = ift.ScalingOperator(1., model.domain)
    pos = S.draw_sample()
    ift.extra.check_value_gradient_consistency(model, pos, ntries=20)

    model2 = ift.CorrelatedField(space, model)
    S = ift.ScalingOperator(1., model2.domain)
    pos = S.draw_sample()
    ift.extra.check_value_gradient_consistency(model2, pos, ntries=20)


def testPointModel(space, seed):
    """Check gradients of an InverseGammaOperator on ``space``."""
    # Seed numpy's global RNG so the drawn position is reproducible per
    # ``seed`` instead of depending on state left behind by other tests.
    np.random.seed(seed)
    S = ift.ScalingOperator(1., space)
    pos = S.draw_sample()
    alpha = 1.5
    q = 0.73
    model = ift.InverseGammaOperator(space, alpha, q)
    # FIXME All those cdfs and ppfs are not very accurate
    ift.extra.check_value_gradient_consistency(model, pos, tol=1e-2, ntries=20)


@pmp('domain', [ift.RGSpace(64, distances=.789),
                ift.RGSpace([32, 32], distances=.789),
                ift.RGSpace([32, 32, 8], distances=.789)])
@pmp('causal', [True, False])
@pmp('minimum_phase', [True, False])
@pmp('seed', [4, 78, 23])
def testDynamicModel(domain, causal, minimum_phase, seed):
    """Check gradients of ``dynamic_operator`` and, on multi-dimensional
    domains, of ``dynamic_lightcone_operator``.

    Each operator is checked for every combination of ``causal`` and
    ``minimum_phase``.
    """
    # The parametrized seed was previously never applied. Seed numpy's
    # global RNG so the drawn positions are reproducible per ``seed``.
    np.random.seed(seed)
    model, _ = ift.dynamic_operator(domain, None, 1., 1., 'f',
                                    causal=causal,
                                    minimum_phase=minimum_phase)
    S = ift.ScalingOperator(1., model.domain)
    pos = S.draw_sample()
    # FIXME I don't know why a smaller tol fails for the 3D example
    ift.extra.check_value_gradient_consistency(model, pos, tol=1e-5,
                                               ntries=20)
    # The light-cone variant is only exercised on domains with more than
    # one axis.
    if len(domain.shape) > 1:
        model, _ = ift.dynamic_lightcone_operator(domain, None, 3., 1.,
                                                  'f', 'c', 1., 5,
                                                  causal=causal,
                                                  minimum_phase=minimum_phase)
        S = ift.ScalingOperator(1., model.domain)
        pos = S.draw_sample()
        # FIXME I don't know why a smaller tol fails for the 3D example
        ift.extra.check_value_gradient_consistency(model, pos, tol=1e-5,
                                                   ntries=20)