# test_model_gradients.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

import unittest
from itertools import product
from test.common import expand

import nifty5 as ift
import numpy as np


class Model_Tests(unittest.TestCase):
    """Value/gradient consistency checks for NIFTy operators and models.

    Every test builds an operator, obtains a position in its domain and
    passes both to ``ift.extra.check_value_gradient_consistency``.
    """

    @staticmethod
    def make_linearization(type, space, seed):
        """Return a random ``ift.Linearization`` on ``space``.

        Parameters
        ----------
        type : str
            ``"Constant"`` to wrap the sample with
            ``ift.Linearization.make_const``, ``"Variable"`` to wrap it
            with ``ift.Linearization.make_var``.
        space : domain
            Domain passed to ``ift.ScalingOperator``; callers in this class
            pass either a single space or an ``ift.MultiDomain``.
        seed : int
            Seed for NumPy's global RNG, set before the sample is drawn.

        Raises
        ------
        ValueError
            If ``type`` is neither ``"Constant"`` nor ``"Variable"``.
        """
        # NOTE: ``type`` shadows the builtin; the name is kept so existing
        # keyword callers keep working.
        np.random.seed(seed)
        S = ift.ScalingOperator(1., space)
        s = S.draw_sample()
        if type == "Constant":
            return ift.Linearization.make_const(s)
        elif type == "Variable":
            return ift.Linearization.make_var(s)
        raise ValueError('unknown type passed')

    @expand(product(
        [ift.GLSpace(15),
         ift.RGSpace(64, distances=.789),
         ift.RGSpace([32, 32], distances=.789)],
        [4, 78, 23]
        ))
    def testBasics(self, space, seed):
        """Check a plain ScalingOperator at a random position."""
        var = self.make_linearization("Variable", space, seed)
        model = ift.ScalingOperator(6., var.target)
        ift.extra.check_value_gradient_consistency(model, var.val)

    @expand(product(
        ['Variable', 'Constant'],
        ['Variable'],
        [ift.GLSpace(15),
         ift.RGSpace(64, distances=.789),
         ift.RGSpace([32, 32], distances=.789)],
        [4, 78, 23]
        ))
    def testBinary(self, type1, type2, space, seed):
        """Check operators combining the fields "s1" and "s2".

        Covers product, sum, scaling, a ScalingOperator applied to the
        product, ``positive_tanh`` of that and, for RGSpaces only, an
        FFTOperator applied to the product.
        """
        dom1 = ift.MultiDomain.make({'s1': space})
        # NOTE(review): lin1/lin2 are never used below. The second call
        # re-seeds NumPy's RNG, so only its draw affects the positions
        # produced by ift.from_random; type1/type2 therefore do not change
        # what is checked. Kept to preserve the existing random stream.
        lin1 = self.make_linearization(type1, dom1, seed)
        dom2 = ift.MultiDomain.make({'s2': space})
        lin2 = self.make_linearization(type2, dom2, seed)

        dom = ift.MultiDomain.union((dom1, dom2))

        # Product of the two fields.
        model = ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2")
        pos = ift.from_random("normal", dom)
        ift.extra.check_value_gradient_consistency(model, pos)

        # Sum of the two fields.
        model = ift.FieldAdapter(dom, "s1")+ift.FieldAdapter(dom, "s2")
        pos = ift.from_random("normal", dom)
        ift.extra.check_value_gradient_consistency(model, pos)

        # Scaling of a single field.
        model = ift.FieldAdapter(dom, "s1").scale(3.)
        pos = ift.from_random("normal", dom)
        ift.extra.check_value_gradient_consistency(model, pos)

        # Linear operator applied to the product.
        model = ift.ScalingOperator(2.456, space)(
            ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2"))
        pos = ift.from_random("normal", dom)
        ift.extra.check_value_gradient_consistency(model, pos)

        # Nonlinearity on top of the scaled product.
        model = ift.positive_tanh(ift.ScalingOperator(2.456, space)(
            ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2")))
        pos = ift.from_random("normal", dom)
        ift.extra.check_value_gradient_consistency(model, pos)

        # The FFT variant is only exercised on RGSpace inputs.
        if isinstance(space, ift.RGSpace):
            model = ift.FFTOperator(space)(
                ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2"))
            pos = ift.from_random("normal", dom)
            ift.extra.check_value_gradient_consistency(model, pos)

    @expand(product(
        [ift.GLSpace(15),
         ift.RGSpace(64, distances=.789),
         ift.RGSpace([32, 32], distances=.789)],
        [4],
        [0.5],
        [2.],
        [3.],
        [1.5],
        [1.75],
        [1.3],
        [4, 78, 23],
        ))
    def testModelLibrary(self, space, Npixdof, ceps_a,
                         ceps_k, sm, sv, im, iv, seed):
        """Check AmplitudeModel and a CorrelatedField built on top of it."""
        # Tests the amplitude model and the correlated field model.
        np.random.seed(seed)
        model = ift.AmplitudeModel(space, Npixdof, ceps_a, ceps_k, sm,
                                   sv, im, iv)
        S = ift.ScalingOperator(1., model.domain)
        pos = S.draw_sample()
        ift.extra.check_value_gradient_consistency(model, pos)

        model2 = ift.CorrelatedField(space, model)
        S = ift.ScalingOperator(1., model2.domain)
        pos = S.draw_sample()
        ift.extra.check_value_gradient_consistency(model2, pos)

    # NOTE(review): the two tests below are disabled. testMultiModel calls
    # self.make_model, which this class does not define (only
    # make_linearization exists), and both consistency checks omit the
    # position argument that the active tests pass.
#     @expand(product(
#         [ift.GLSpace(15),
#          ift.RGSpace(64, distances=.789),
#          ift.RGSpace([32, 32], distances=.789)],
#         [4, 78, 23]))
#     def testPointModel(self, space, seed):
#
#         S = ift.ScalingOperator(1., space)
#         pos = ift.MultiField.from_dict(
#                 {'points': S.draw_sample()})
#         alpha = 1.5
#         q = 0.73
#         model = ift.InverseGammaModel(pos, alpha, q, 'points')
#         # FIXME All those cdfs and ppfs are not very accurate
#         ift.extra.check_value_gradient_consistency(model, tol=1e-5)
#
#     @expand(product(
#         ['Variable', 'Constant'],
#         [ift.GLSpace(15),
#          ift.RGSpace(64, distances=.789),
#          ift.RGSpace([32, 32], distances=.789)],
#         [4, 78, 23]
#         ))
#     def testMultiModel(self, type, space, seed):
#         model = self.make_model(
#             type, space_key='s', space=space, seed=seed)['s']
#         mmodel = ift.MultiModel(model, 'g')
#         ift.extra.check_value_gradient_consistency(mmodel)