# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np
import pytest
from numpy.testing import assert_

import nifty5 as ift

from ..common import list2fixture

# Short alias for stacking pytest parametrizations on the tests below.
pmp = pytest.mark.parametrize

# Spaces that every ``space``-dependent test runs on.  ``list2fixture`` comes
# from ..common and presumably turns the list into a parametrized fixture.
# ``space1`` is the very same fixture object under a second name, presumably
# so a test can request two spaces independently.
space = space1 = list2fixture([
    ift.GLSpace(15),
    ift.RGSpace(64, distances=.789),
    ift.RGSpace([32, 32], distances=.789),
])

# Seeds fed to the global NumPy RNG by the tests that take ``seed``.
seed = list2fixture([4, 78, 23])


def _make_linearization(type, space, seed):
    """Wrap a random sample on ``space`` into an ``ift.Linearization``.

    The global NumPy RNG is reseeded with ``seed`` before a sample is drawn
    from a unit ``ScalingOperator`` on ``space``.

    Parameters
    ----------
    type : str
        ``"Constant"`` wraps the sample with ``Linearization.make_const``,
        ``"Variable"`` with ``Linearization.make_var``.
    space : domain accepted by ``ift.ScalingOperator``
    seed : int
        Seed passed to ``np.random.seed``.

    Raises
    ------
    ValueError
        If ``type`` is not one of the two names above.  The seed is set and
        the sample drawn before this check.
    """
    np.random.seed(seed)
    sample = ift.ScalingOperator(1., space).draw_sample()
    factories = (
        ("Constant", ift.Linearization.make_const),
        ("Variable", ift.Linearization.make_var),
    )
    for name, factory in factories:
        if type == name:
            return factory(sample)
    raise ValueError('unknown type passed')


def testBasics(space, seed):
    """A pure scaling by 6 must have values and gradients that agree."""
    lin = _make_linearization("Variable", space, seed)
    ift.extra.check_value_gradient_consistency(
        ift.ScalingOperator(6., lin.target), lin.val)


@pmp('type1', ['Variable', 'Constant'])
@pmp('type2', ['Variable'])
def testBinary(type1, type2, space, seed):
    """Check value/gradient consistency of operators built from two fields.

    Exercises products, sums, scaling, sigmoid, outer products, powers,
    clipping and (on ``RGSpace``) FFTs of the fields ``s1`` and ``s2``.
    """
    dom1 = ift.MultiDomain.make({'s1': space})
    dom2 = ift.MultiDomain.make({'s2': space})

    # The returned linearizations are deliberately discarded: these calls
    # reseed the global NumPy RNG with ``seed`` (and exercise building both
    # linearization types), which makes the random positions below
    # reproducible per parametrization.  Do not remove them.
    _make_linearization(type1, dom1, seed)
    _make_linearization(type2, dom2, seed)

    dom = ift.MultiDomain.union((dom1, dom2))
    select_s1 = ift.ducktape(None, dom, "s1")
    select_s2 = ift.ducktape(None, dom, "s2")

    def check(model, domain):
        # Draw a fresh normal random position on ``domain`` and test there.
        pos = ift.from_random("normal", domain)
        ift.extra.check_value_gradient_consistency(model, pos, ntries=20)

    check(select_s1*select_s2, dom)
    check(select_s1 + select_s2, dom)
    check(select_s1.scale(3.), dom1)
    check(ift.ScalingOperator(2.456, space)(select_s1*select_s2), dom)
    check(ift.sigmoid(2.456*(select_s1*select_s2)), dom)

    # The outer product is built from one random field and tested at another.
    pos = ift.from_random("normal", dom)
    model = ift.OuterProduct(pos['s1'], ift.makeDomain(space))
    ift.extra.check_value_gradient_consistency(model, pos['s2'], ntries=20)

    check(select_s1**2, dom1)
    check(select_s1.clip(-1, 1), dom1)
    if isinstance(space, ift.RGSpace):
        check(ift.FFTOperator(space)(select_s1*select_s2), dom)


def testModelLibrary(space, seed):
    """Check gradients of the amplitude model and correlated-field models.

    Covers ``SLAmplitude``, a ``CorrelatedField`` built on it, and an
    ``MfCorrelatedField`` on the product domain ``(space, space)``.
    """
    np.random.seed(seed)

    def check_at_random_position(op):
        # Sample from a unit ScalingOperator on op's domain and test there.
        pos = ift.ScalingOperator(1., op.domain).draw_sample()
        ift.extra.check_value_gradient_consistency(op, pos, ntries=20)

    domain = ift.PowerSpace(space.get_default_codomain())
    model = ift.SLAmplitude(target=domain, n_pix=4, a=.5, k0=2, sm=3, sv=1.5,
                            im=1.75, iv=1.3)
    assert_(isinstance(model, ift.Operator))
    check_at_random_position(model)

    model2 = ift.CorrelatedField(space, model)
    check_at_random_position(model2)

    domtup = ift.DomainTuple.make((space, space))
    model3 = ift.MfCorrelatedField(domtup, [model, model])
    check_at_random_position(model3)


def testPointModel(space, seed):
    """Check gradients of ``InverseGammaOperator(space, 1.5, 0.73)``."""
    # Seed the global RNG so each ``seed`` parametrization draws its own,
    # reproducible position (the fixture was otherwise unused).
    np.random.seed(seed)
    pos = ift.ScalingOperator(1., space).draw_sample()
    alpha = 1.5
    q = 0.73
    model = ift.InverseGammaOperator(space, alpha, q)
    # FIXME: the cdfs and ppfs involved are not very accurate, hence the
    # loose tolerance.
    ift.extra.check_value_gradient_consistency(model, pos, tol=1e-2, ntries=20)


@pmp('target', [
    ift.RGSpace(64, distances=.789, harmonic=True),
    ift.RGSpace([32, 32], distances=.789, harmonic=True),
    ift.RGSpace([32, 32, 8], distances=.789, harmonic=True)
])
@pmp('causal', [True, False])
@pmp('minimum_phase', [True, False])
@pmp('seed', [4, 78, 23])
def testDynamicModel(target, causal, minimum_phase, seed):
    """Check gradients of ``dynamic_operator``; for targets with more than
    one axis, also of ``dynamic_lightcone_operator``."""
    # Seed the global RNG so each ``seed`` parametrization draws its own,
    # reproducible positions (the parameter was otherwise unused).
    np.random.seed(seed)

    dct = {
        'target': target,
        'harmonic_padding': None,
        'sm_s0': 3.,
        'sm_x0': 1.,
        'key': 'f',
        'causal': causal,
        'minimum_phase': minimum_phase
    }
    model, _ = ift.dynamic_operator(**dct)
    pos = ift.ScalingOperator(1., model.domain).draw_sample()
    # FIXME: it is unclear why a smaller tol fails for the 3D example.
    ift.extra.check_value_gradient_consistency(model, pos, tol=1e-5, ntries=20)

    if len(target.shape) > 1:
        # Same settings as above plus the light-cone specific parameters.
        lightcone_dct = dict(dct, lightcone_key='c', sigc=1., quant=5)
        model, _ = ift.dynamic_lightcone_operator(**lightcone_dct)
        pos = ift.ScalingOperator(1., model.domain).draw_sample()
        # FIXME: it is unclear why a smaller tol fails for the 3D example.
        ift.extra.check_value_gradient_consistency(
            model, pos, tol=1e-5, ntries=20)