# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

import nifty6 as ift
from numpy.testing import assert_, assert_allclose
import pytest
from .common import setup_function, teardown_function

# Shorthand for pytest's parametrize decorator, used on the test below.
pmp = pytest.mark.parametrize


@pmp('constants', ([], ['a'], ['b'], ['a', 'b']))
@pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b']))
@pmp('mirror_samples', (True, False))
def test_kl(constants, point_estimates, mirror_samples):
    """Compare a MetricGaussianKL against a reference KL on the same samples.

    A KL is built with the parametrized ``constants``, ``point_estimates``
    and ``mirror_samples`` settings. A second, "pure" KL is then built from
    the samples of the first one, without any of these options. The test
    checks that:

    * both KLs have the same value,
    * their gradients agree, except that the gradient of the first KL is
      zero for every key listed in ``constants``,
    * the number of samples is ``nsamps`` (doubled when ``mirror_samples``
      is set),
    * the samples are zero for every key listed in ``point_estimates``,
    * after a few minimizer iterations the position has not moved along
      any key listed in ``constants``.
    """
    # Model: product of a field 'a' with a smoothed field 'b'.
    dom = ift.RGSpace((12,), 2.12)
    op0 = ift.HarmonicSmoothingOperator(dom, 3)
    op = ift.ducktape(dom, None, 'a')*(op0.ducktape('b'))
    lh = ift.GaussianEnergy(domain=op.target) @ op
    ic = ift.GradientNormController(iteration_limit=5)
    h = ift.StandardHamiltonian(lh, ic_samp=ic)
    mean0 = ift.from_random('normal', h.domain)

    nsamps = 2
    kl = ift.MetricGaussianKL(mean0,
                              h,
                              nsamps,
                              constants=constants,
                              point_estimates=point_estimates,
                              mirror_samples=mirror_samples,
                              napprox=0)

    # Reference KL: reuse the samples drawn above (already including any
    # mirrored ones, hence mirror_samples=False) and set no constants or
    # point estimates.
    samp_full = kl.samples
    klpure = ift.MetricGaussianKL(mean0,
                                  h,
                                  len(samp_full),
                                  mirror_samples=False,
                                  napprox=0,
                                  _samples=samp_full)

    # Test value
    assert_allclose(kl.value, klpure.value)

    # Test gradient
    for kk in h.domain.keys():
        res0 = klpure.gradient[kk].val
        if kk in constants:
            # The gradient is expected to vanish along constant keys.
            res0 = 0*res0
        res1 = kl.gradient[kk].val
        assert_allclose(res0, res1)

    # Test number of samples
    expected_nsamps = 2*nsamps if mirror_samples else nsamps
    assert_(len(kl.samples) == expected_nsamps)

    # Test point_estimates (after drawing samples)
    for kk in point_estimates:
        for samp in kl.samples:
            vals = samp[kk].val
            assert_allclose(vals, 0*vals)

    # Test constants (after some minimization)
    cg = ift.GradientNormController(iteration_limit=5)
    minimizer = ift.NewtonCG(cg)
    kl_min, _ = minimizer(kl)
    diff = (mean0 - kl_min.position).to_dict()
    for kk in constants:
        assert_allclose(diff[kk].val, 0*diff[kk].val)