# test_model_gradients.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

import unittest
from itertools import product
from test.common import expand

import nifty5 as ift
import numpy as np


class Model_Tests(unittest.TestCase):
    """Value/gradient consistency tests for NIFTy operators and models.

    Each test builds an operator, obtains a random position in its domain
    and passes both to ``ift.extra.check_value_gradient_consistency``.
    The ``@expand(product(...))`` decorators (from ``test.common``)
    presumably turn each test into one case per element of the Cartesian
    product of the listed parameter values.
    """

    @staticmethod
    def make_linearization(type, space, seed):
        """Return a constant or variable Linearization at a random point.

        Parameters
        ----------
        type : str
            ``"Constant"`` selects ``ift.Linearization.make_const``,
            ``"Variable"`` selects ``ift.Linearization.make_var``.
            (The name shadows the builtin ``type``; it is kept so that
            keyword callers keep working.)
        space : domain
            Domain of the ``ift.ScalingOperator`` used to draw the sample.
        seed : int
            Passed to ``np.random.seed`` before the sample is drawn.

        Raises
        ------
        ValueError
            If ``type`` is neither ``"Constant"`` nor ``"Variable"``.
        """
        np.random.seed(seed)
        S = ift.ScalingOperator(1., space)
        s = S.draw_sample()
        if type == "Constant":
            return ift.Linearization.make_const(s)
        elif type == "Variable":
            return ift.Linearization.make_var(s)
        raise ValueError('unknown type passed')

    @expand(product(
        [ift.GLSpace(15),
         ift.RGSpace(64, distances=.789),
         ift.RGSpace([32, 32], distances=.789)],
        [4, 78, 23]
        ))
    def testBasics(self, space, seed):
        """Check a ScalingOperator(6.) at a seeded random point."""
        var = self.make_linearization("Variable", space, seed)
        model = ift.ScalingOperator(6., var.target)
        ift.extra.check_value_gradient_consistency(model, var.val)

    @expand(product(
        ['Variable', 'Constant'],
        ['Variable'],
        [ift.GLSpace(15),
         ift.RGSpace(64, distances=.789),
         ift.RGSpace([32, 32], distances=.789)],
        [4, 78, 23]
        ))
    def testBinary(self, type1, type2, space, seed):
        """Check products, sums and compositions of two FieldAdapters.

        The operators act on the multi-domain with keys ``'s1'`` and
        ``'s2'``, both defined over ``space``.
        """
        dom1 = ift.MultiDomain.make({'s1': space})
        # The returned linearizations are never used below; the calls are
        # kept because make_linearization reseeds NumPy's global RNG and
        # draws from it, which presumably fixes the ift.from_random draws
        # that follow.
        self.make_linearization(type1, dom1, seed)
        dom2 = ift.MultiDomain.make({'s2': space})
        self.make_linearization(type2, dom2, seed)

        dom = ift.MultiDomain.union((dom1, dom2))

        # Pointwise product of the two inputs.
        model = ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2")
        pos = ift.from_random("normal", dom)
        ift.extra.check_value_gradient_consistency(model, pos, ntries=20)

        # Sum of the two inputs.
        model = ift.FieldAdapter(space, "s1")+ift.FieldAdapter(space, "s2")
        pos = ift.from_random("normal", dom)
        ift.extra.check_value_gradient_consistency(model, pos, ntries=20)

        # Scaled single input; only needs the 's1' sub-domain.
        model = ift.FieldAdapter(space, "s1").scale(3.)
        pos = ift.from_random("normal", dom1)
        ift.extra.check_value_gradient_consistency(model, pos, ntries=20)

        # Linear operator applied to a product.
        model = ift.ScalingOperator(2.456, space)(
            ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2"))
        pos = ift.from_random("normal", dom)
        ift.extra.check_value_gradient_consistency(model, pos, ntries=20)

        # Nonlinearity on top of the scaled product.
        model = ift.positive_tanh(ift.ScalingOperator(2.456, space)(
            ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2")))
        pos = ift.from_random("normal", dom)
        ift.extra.check_value_gradient_consistency(model, pos, ntries=20)

        # Outer product with a fixed field, evaluated at the 's2' part.
        pos = ift.from_random("normal", dom)
        model = ift.OuterProduct(pos['s1'], ift.makeDomain(space))
        ift.extra.check_value_gradient_consistency(model, pos['s2'], ntries=20)

        # The FFT check only applies to regular-grid spaces.
        if isinstance(space, ift.RGSpace):
            model = ift.FFTOperator(space)(
                ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2"))
            pos = ift.from_random("normal", dom)
            ift.extra.check_value_gradient_consistency(model, pos, ntries=20)

    @expand(product(
        [ift.GLSpace(15),
         ift.RGSpace(64, distances=.789),
         ift.RGSpace([32, 32], distances=.789)],
        [4],
        [0.5],
        [2.],
        [3.],
        [1.5],
        [1.75],
        [1.3],
        [4, 78, 23],
        ))
    def testModelLibrary(self, space, Npixdof, ceps_a,
                         ceps_k, sm, sv, im, iv, seed):
        """Check AmplitudeModel and a CorrelatedField built on top of it."""
        # tests amplitude model and correlated field model
        np.random.seed(seed)
        model = ift.AmplitudeModel(space, Npixdof, ceps_a, ceps_k, sm,
                                   sv, im, iv)
        S = ift.ScalingOperator(1., model.domain)
        pos = S.draw_sample()
        ift.extra.check_value_gradient_consistency(model, pos, ntries=20)

        model2 = ift.CorrelatedField(space, model)
        S = ift.ScalingOperator(1., model2.domain)
        pos = S.draw_sample()
        ift.extra.check_value_gradient_consistency(model2, pos, ntries=20)

    @expand(product(
        [ift.GLSpace(15),
         ift.RGSpace(64, distances=.789),
         ift.RGSpace([32, 32], distances=.789)],
        [4, 78, 23]))
    def testPointModel(self, space, seed):
        """Check InverseGammaModel(alpha=1.5, q=0.73) at a seeded point."""
        # Seed like the other tests; previously ``seed`` was ignored, so the
        # seed parameterization had no effect and runs were not reproducible.
        np.random.seed(seed)
        S = ift.ScalingOperator(1., space)
        pos = S.draw_sample()
        alpha = 1.5
        q = 0.73
        model = ift.InverseGammaModel(space, alpha, q)
        # FIXME All those cdfs and ppfs are not very accurate
        ift.extra.check_value_gradient_consistency(model, pos, tol=1e-2,
                                                   ntries=20)
#     @expand(product(
#         ['Variable', 'Constant'],
#         [ift.GLSpace(15),
#          ift.RGSpace(64, distances=.789),
#          ift.RGSpace([32, 32], distances=.789)],
#         [4, 78, 23]
#         ))
#     def testMultiModel(self, type, space, seed):
#         model = self.make_model(
#             type, space_key='s', space=space, seed=seed)['s']
#         mmodel = ift.MultiModel(model, 'g')
#         ift.extra.check_value_gradient_consistency(mmodel)