test_energy_gradients.py 4.5 KB
Newer Older
Philipp Arras's avatar
Philipp Arras committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
14
# Copyright(C) 2013-2020 Max-Planck-Society
Philipp Arras's avatar
Philipp Arras committed
15
16
17
18
19
20
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np
import pytest

Martin Reinecke's avatar
Martin Reinecke committed
21
import nifty7 as ift
Philipp Arras's avatar
Philipp Arras committed
22

Philipp Arras's avatar
Fixups    
Philipp Arras committed
23
from .common import list2fixture, setup_function, teardown_function
Philipp Arras's avatar
Philipp Arras committed
24

Philipp Arras's avatar
Fixups    
Philipp Arras committed
25
26
27
# Domains every energy test below is run on: a GLSpace, a MultiDomain with a
# single empty-string key, and a tuple of an RGSpace and an UnstructuredDomain.
spaces = [
    ift.GLSpace(5),
    ift.MultiDomain.make({'': ift.RGSpace(5, distances=.789)}),
    (ift.RGSpace(3, distances=.789), ift.UnstructuredDomain(2)),
]
pmp = pytest.mark.parametrize
# One random field per domain, drawn when this module is imported and handed
# to the tests via the `field` fixture built by list2fixture.
field = list2fixture([ift.from_random(dom, 'normal') for dom in spaces])
# Number of random trials passed to ift.extra.check_operator where requested.
ntries = 10
Philipp Arras's avatar
Philipp Arras committed
31

Philipp Arras's avatar
Philipp Arras committed
32

Philipp Arras's avatar
Philipp Arras committed
33
34
def test_gaussian(field):
    """Run check_operator on a GaussianEnergy defined on the field's domain."""
    gauss = ift.GaussianEnergy(domain=field.domain)
    ift.extra.check_operator(gauss, field)
Philipp Arras's avatar
Philipp Arras committed
36

Philipp Arras's avatar
Philipp Arras committed
37

Philipp Arras's avatar
Philipp Arras committed
38
39
def test_ScaledEnergy(field):
    """Check a scaled GaussianEnergy and that scaling carries over to its metric.

    The metric of ``energy.scale(factor)`` applied to ``field`` must equal
    ``factor`` times the metric of the unscaled energy applied to ``field``.
    Drawing a sample from either metric must also succeed.
    """
    factor = 0.3
    inv_cov = ift.ScalingOperator(field.domain, 1.2)
    energy = ift.GaussianEnergy(inverse_covariance=inv_cov,
                                sampling_dtype=np.float64)
    ift.extra.check_operator(energy.scale(factor), field)

    lin = ift.Linearization.make_var(field, want_metric=True)
    metric_plain = energy(lin).metric
    metric_scaled = energy.scale(factor)(lin).metric

    # Undo the scaling on the scaled result; both should then coincide.
    expected = metric_plain(field)
    actual = metric_scaled(field)/factor
    ift.extra.assert_allclose(expected, actual, 0, 1e-12)

    metric_plain.draw_sample()
    metric_scaled.draw_sample()
Martin Reinecke's avatar
Martin Reinecke committed
51

Philipp Arras's avatar
Philipp Arras committed
52

Philipp Arras's avatar
Philipp Arras committed
53
54
def test_QuadraticFormOperator(field):
    """Run check_operator on a QuadraticFormOperator.

    The operator inside the quadratic form is built with ``ift.makeOp`` from
    a float64 sample drawn from a ScalingOperator (factor 1.2).
    """
    scaling = ift.ScalingOperator(field.domain, 1.2)
    sample = scaling.draw_sample_with_dtype(dtype=np.float64)
    quad_form = ift.QuadraticFormOperator(ift.makeOp(sample))
    ift.extra.check_operator(quad_form, field)
Philipp Arras's avatar
Philipp Arras committed
58
59


60
def test_studentt(field):
    """Run check_operator on StudentTEnergy with a scalar and a field-valued theta.

    The test is not run for MultiDomain fields. It is reported as skipped
    rather than returning early, so pytest does not count it as a pass.
    """
    if isinstance(field.domain, ift.MultiDomain):
        pytest.skip("not run on MultiDomain fields")
    energy = ift.StudentTEnergy(domain=field.domain, theta=.5)
    ift.extra.check_operator(energy, field)
    # Exponentiate so that every entry of the field-valued theta is positive.
    theta = ift.from_random(field.domain, 'normal').exp()
    energy = ift.StudentTEnergy(domain=field.domain, theta=theta)
    ift.extra.check_operator(energy, field, ntries=ntries)
Philipp Arras's avatar
Philipp Arras committed
68

Martin Reinecke's avatar
Martin Reinecke committed
69

Philipp Arras's avatar
Philipp Arras committed
70
def test_hamiltonian_and_KL(field):
    """Run check_operator on a StandardHamiltonian and on a sample-averaged KL.

    The Hamiltonian wraps a GaussianEnergy likelihood. The KL is an
    AveragedEnergy of that Hamiltonian over two random samples. Both are
    evaluated at the pointwise exponential of ``field``.
    """
    field = field.ptw("exp")
    dom = field.domain
    hamiltonian = ift.StandardHamiltonian(ift.GaussianEnergy(domain=dom))
    ift.extra.check_operator(hamiltonian, field, ntries=ntries)
    samples = [ift.from_random(dom, 'normal') for _ in range(2)]
    kl = ift.AveragedEnergy(hamiltonian, samples)
    ift.extra.check_operator(kl, field, ntries=ntries)
Philipp Arras's avatar
Philipp Arras committed
79
80
81
82
83


def test_variablecovariancegaussian(field):
    """Run check_operator on VariableCovarianceGaussianEnergy.

    The evaluation point is a MultiField with key 'a' holding ``field`` and
    key 'b' holding its pointwise exponential, so 'b' is positive. Sampling
    from the metric at that point is exercised too. MultiDomain fields are
    reported as skipped instead of silently passing.
    """
    if isinstance(field.domain, ift.MultiDomain):
        pytest.skip("not run on MultiDomain fields")
    mf = ift.MultiField.from_dict({'a': field, 'b': field.ptw("exp")})
    energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b', np.float64)
    ift.extra.check_operator(energy, mf, ntries=ntries)
    energy(ift.Linearization.make_var(mf, want_metric=True)).metric.draw_sample()
Philipp Arras's avatar
Philipp Arras committed
89
90


91
92
93
94
95
def test_specialgamma(field):
    """Run check_operator on the private _SpecialGammaEnergy.

    The energy is constructed from ``field``. It is evaluated at the pointwise
    exponential of a random field on the energy's domain, and sampling from
    the metric at that point is exercised too. MultiDomain fields are
    reported as skipped instead of silently passing.
    """
    if isinstance(field.domain, ift.MultiDomain):
        pytest.skip("not run on MultiDomain fields")
    energy = ift.operators.energy_operators._SpecialGammaEnergy(field)
    loc = ift.from_random(energy.domain).exp()
    ift.extra.check_operator(energy, loc, ntries=ntries)
    energy(ift.Linearization.make_var(loc, want_metric=True)).metric.draw_sample()


Philipp Arras's avatar
Philipp Arras committed
100
def test_inverse_gamma(field):
    """Run check_operator on InverseGammaLikelihood with synthetic data.

    The evaluation point is the pointwise exponential of ``field``, hence
    positive. The data are the squares of ``normal(10, ...)`` draws from the
    current RNG. MultiDomain fields are reported as skipped instead of
    silently passing.
    """
    if isinstance(field.domain, ift.MultiDomain):
        pytest.skip("not run on MultiDomain fields")
    field = field.ptw("exp")
    space = field.domain
    d = ift.random.current_rng().normal(10, size=space.shape)**2
    d = ift.Field(space, d)
    energy = ift.InverseGammaLikelihood(d)
    ift.extra.check_operator(energy, field, tol=1e-10)
Philipp Arras's avatar
Philipp Arras committed
109
110
111


def testPoissonian(field):
    """Run check_operator on PoissonianEnergy with synthetic count data.

    The evaluation point is the pointwise exponential of ``field``, hence
    positive. The data are Poisson draws with rate 120 from the current RNG.
    MultiDomain fields are reported as skipped instead of silently passing.
    """
    if isinstance(field.domain, ift.MultiDomain):
        pytest.skip("not run on MultiDomain fields")
    field = field.ptw("exp")
    space = field.domain
    d = ift.random.current_rng().poisson(120, size=space.shape)
    d = ift.Field(space, d)
    energy = ift.PoissonianEnergy(d)
    ift.extra.check_operator(energy, field)
Philipp Arras's avatar
Philipp Arras committed
120
121
122


def test_bernoulli(field):
    """Run check_operator on BernoulliEnergy with synthetic binary data.

    The evaluation point is the pointwise sigmoid of ``field``. The data are
    0/1 draws from ``binomial(1, 0.1, ...)`` on the current RNG. MultiDomain
    fields are reported as skipped instead of silently passing.
    """
    if isinstance(field.domain, ift.MultiDomain):
        pytest.skip("not run on MultiDomain fields")
    field = field.ptw("sigmoid")
    space = field.domain
    d = ift.random.current_rng().binomial(1, 0.1, size=space.shape)
    d = ift.Field(space, d)
    energy = ift.BernoulliEnergy(d)
    ift.extra.check_operator(energy, field, tol=1e-10)