# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np
import pytest
from numpy.testing import assert_, assert_allclose

import nifty6 as ift

# pytest automatically calls module-level `setup_function` and
# `teardown_function` around every test function in this module, so these
# names are used even though nothing here calls them explicitly.
from .common import setup_function, teardown_function  # noqa: F401

# Shorthand for pytest's parametrize decorator used by the tests below.
pmp = pytest.mark.parametrize


@pmp('constants', ([], ['a'], ['b'], ['a', 'b']))
@pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b']))
@pmp('mirror_samples', (True, False))
@pmp('mf', (True, False))
def test_kl(constants, point_estimates, mirror_samples, mf):
    """Check `MetricGaussianKL` with constants and point estimates.

    A reference KL (`klpure`) is built from the same local samples but
    without `constants`/`point_estimates`. The test then checks:

    - the KL value agrees with the reference;
    - the gradient agrees, except that constant keys get a zero gradient;
    - the number of samples (doubled when `mirror_samples` is set);
    - point-estimated keys are zero in every sample;
    - constant keys are unchanged after a few NewtonCG steps.

    With ``mf=False`` the model lives on a single RGSpace, so there are no
    keys 'a'/'b' for `constants`/`point_estimates` to refer to. Those
    combinations are skipped rather than allowed to crash on
    MultiField-only operations.
    """
    if not mf and (constants or point_estimates):
        pytest.skip("constants/point_estimates need a multi-field model")

    dom = ift.RGSpace((12,), 2.12)
    op = ift.HarmonicSmoothingOperator(dom, 3)
    if mf:
        # Two-input model: product of input 'a' with the smoothed input 'b'.
        op = ift.ducktape(dom, None, 'a')*(op.ducktape('b'))
    lh = ift.GaussianEnergy(domain=op.target) @ op
    ic = ift.GradientNormController(iteration_limit=5)
    h = ift.StandardHamiltonian(lh, ic_samp=ic)
    mean0 = ift.from_random('normal', h.domain)

    nsamps = 2
    kl = ift.MetricGaussianKL(mean0,
                              h,
                              nsamps,
                              constants=constants,
                              point_estimates=point_estimates,
                              mirror_samples=mirror_samples,
                              napprox=0)
    # Reuse kl's samples so both KLs are evaluated on identical draws.
    locsamp = kl._local_samples
    klpure = ift.MetricGaussianKL(mean0,
                                  h,
                                  nsamps,
                                  mirror_samples=mirror_samples,
                                  napprox=0,
                                  _local_samples=locsamp)

    # Test value
    assert_allclose(kl.value, klpure.value)

    # Test gradient
    if not mf:
        # Single-field model: no keys to iterate over, compare directly.
        assert_allclose(klpure.gradient.val, kl.gradient.val)
    else:
        for kk in h.domain.keys():
            res0 = klpure.gradient[kk].val
            if kk in constants:
                # Constant keys must not receive any gradient.
                res0 = 0*res0
            res1 = kl.gradient[kk].val
            assert_allclose(res0, res1)

    # Test number of samples
    expected_nsamps = 2*nsamps if mirror_samples else nsamps
    assert_(len(tuple(kl.samples)) == expected_nsamps)

    # Test point_estimates (after drawing samples)
    for kk in point_estimates:
        for ss in kl.samples:
            val = ss[kk].val
            assert_allclose(val, 0*val)

    # Test constants (after some minimization)
    cg = ift.GradientNormController(iteration_limit=5)
    minimizer = ift.NewtonCG(cg)
    kl, _ = minimizer(kl)
    if constants:
        # Only multi-field positions can be split by key.
        diff = (mean0 - kl.position).to_dict()
        for kk in constants:
            assert_allclose(diff[kk].val, 0*diff[kk].val)