test_kl.py 3.75 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

18
import numpy as np
19
import pytest
20
from numpy.testing import assert_allclose, assert_raises
Philipp Arras's avatar
Philipp Arras committed
21

Martin Reinecke's avatar
merge    
Martin Reinecke committed
22
import nifty7 as ift
23
from nifty7 import myassert
Philipp Arras's avatar
Philipp Arras committed
24

25
from .common import setup_function, teardown_function
26
27
28
29
30
31
32

# Short alias for pytest's parametrize marker, used as a decorator in this file.
pmp = pytest.mark.parametrize


@pmp('constants', ([], ['a'], ['b'], ['a', 'b']))
@pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b']))
@pmp('mirror_samples', (True, False))
@pmp('mf', (True, False))
@pmp('geo', (True, False))
def test_kl(constants, point_estimates, mirror_samples, mf, geo):
    """Compare a sampled KL energy with a hand-built `_SampledKLEnergy`.

    A `GeoMetricKL` (``geo=True``) or `MetricGaussianKL` (``geo=False``) is
    constructed and its number of samples, value and gradient are checked
    against a `_SampledKLEnergy` that is built from the same local samples.

    Parameters
    ----------
    constants : list of str
        Keys passed to the KL constructor as ``constants``.
    point_estimates : list of str
        Keys passed to the KL constructor as ``point_estimates``.
    mirror_samples : bool
        Passed to the KL constructor as ``mirror_samples``.
    mf : bool
        If True, the model is a product of two ducktaped operators with
        input keys 'a' and 'b'; otherwise a single smoothing operator.
    geo : bool
        Selects `GeoMetricKL` (True) or `MetricGaussianKL` (False).
    """
    # Without the two-key model only the combination with no constants and
    # no point estimates is exercised; everything else returns immediately.
    if not mf and (point_estimates or constants):
        return

    # Model and Hamiltonian
    space = ift.RGSpace((12,), 2.12)
    model = ift.HarmonicSmoothingOperator(space, 3)
    if mf:
        model = ift.ducktape(space, None, 'a')*(model.ducktape('b'))
    likelihood = ift.GaussianEnergy(domain=model.target,
                                    sampling_dtype=np.float64) @ model
    ic = ift.GradientNormController(iteration_limit=5)
    ic.enable_logging()
    h = ift.StandardHamiltonian(likelihood, ic_samp=ic)
    mean0 = ift.from_random(h.domain, 'normal')

    # Keyword arguments for the KL constructor
    nsamps = 2
    kl_kwargs = dict(constants=constants,
                     point_estimates=point_estimates,
                     mirror_samples=mirror_samples,
                     n_samples=nsamps,
                     mean=mean0,
                     hamiltonian=h)
    if geo:
        kl_kwargs['minimizer_samp'] = ift.NewtonCG(ic)
    make_kl = ift.GeoMetricKL if geo else ift.MetricGaussianKL

    multi = isinstance(mean0, ift.MultiField)

    # If every key of the mean is a point estimate, construction must fail.
    if multi and set(point_estimates) == set(mean0.keys()):
        with assert_raises(RuntimeError):
            make_kl(**kl_kwargs)
        return

    kl = make_kl(**kl_kwargs)

    def check_history_lengths():
        # The logged history and its two lists must have the same length.
        myassert(len(ic.history) == len(ic.history.time_stamps))
        myassert(len(ic.history) == len(ic.history.energy_values))

    # The shared controller must have logged something while the KL was
    # built, and resetting must empty the history again.
    myassert(len(ic.history) > 0)
    check_history_lengths()
    ic.history.reset()
    myassert(len(ic.history) == 0)
    check_history_lengths()

    locsamp = kl._local_samples

    # Reference Hamiltonian and mean with the constant keys handled
    if multi:
        _, tmph = h.simplify_for_constant_input(
            mean0.extract_by_keys(constants))
        tmpmean = mean0.extract(tmph.domain)
    else:
        tmph, tmpmean = h, mean0

    # For the geometric KL with mirroring, the reference energy is built with
    # twice the sample count and mirroring switched off.
    # NOTE(review): presumably the mirrored samples are already contained in
    # `locsamp` in that case - confirm against the GeoMetricKL implementation.
    if geo and mirror_samples:
        ref_nsamps, ref_mirror = 2*nsamps, False
    else:
        ref_nsamps, ref_mirror = nsamps, mirror_samples
    klpure = ift.minimization.kl_energies._SampledKLEnergy(
        tmpmean, tmph, ref_nsamps, ref_mirror, None, locsamp, False)

    # Test number of samples
    if mirror_samples:
        expected_nsamps = 2*nsamps
    else:
        expected_nsamps = nsamps
    myassert(len(tuple(kl.samples)) == expected_nsamps)

    # Test value
    assert_allclose(kl.value, klpure.value)

    # Test gradient
    if not mf:
        ift.extra.assert_allclose(kl.gradient, klpure.gradient, 0, 1e-14)
        return

    for key in kl.position.domain.keys():
        actual = kl.gradient[key].val
        # The gradient with respect to a constant key is expected to vanish.
        if key in constants:
            expected = 0*actual
        else:
            expected = klpure.gradient[key].val
        assert_allclose(expected, actual)