test_fisher_metric.py 5.37 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np
import pytest

Philipp Arras's avatar
Philipp Arras committed
21
import nifty7 as ift
22 23 24 25

from ..common import list2fixture, setup_function, teardown_function

# Domains exercised by the tests: a GLSpace, a single-entry MultiDomain and a
# two-component product domain.
spaces = [ift.GLSpace(5),
          ift.MultiDomain.make({'': ift.RGSpace(5, distances=.789)}),
          (ift.RGSpace(3, distances=.789), ift.UnstructuredDomain(2))]
pmp = pytest.mark.parametrize
# Every space is tested once with a real and once with a complex random
# position.
field = list2fixture([ift.from_random(sp, 'normal') for sp in spaces] +
                     [ift.from_random(sp, 'normal', dtype=np.complex128)
                      for sp in spaces])
dtype = list2fixture([np.float64,
                      np.complex128])
# Number of noise realizations used for the Monte-Carlo Fisher estimate.
Nsamp = 2000
# numpy's global RNG is used directly by the Poisson test's data generator.
np.random.seed(42)
36

Philipp Arras's avatar
Philipp Arras committed
37

38 39 40 41 42 43
def _to_array(d):
    if isinstance(d, np.ndarray):
        return d
    assert isinstance(d, dict)
    return np.concatenate(list(d.values()))

Philipp Arras's avatar
Philipp Arras committed
44

45 46 47 48 49 50 51 52
def _complex2real(sp):
    """Build a linear operator that splits a complex field into two real parts.

    The returned operator maps a field on ``sp`` (a DomainTuple, as built by
    the callers in this file) onto the domain ``(UnstructuredDomain(2),) + sp``.
    The real part goes to index 0 of the leading axis and the imaginary part
    to index 1.
    """
    target = ift.DomainTuple.make((ift.UnstructuredDomain(2),) + tuple(sp))
    identity = ift.ScalingOperator(sp, 1)
    real_slot = ift.DomainTupleFieldInserter(target, 0, (0,))
    imag_slot = ift.DomainTupleFieldInserter(target, 0, (1,))
    return real_slot(identity.real) + imag_slot(identity.imag)

Philipp Arras's avatar
Philipp Arras committed
53

54 55 56 57 58 59 60 61
def test_complex2real():
    """Check that ``_complex2real`` is lossless in both directions."""
    dom = ift.makeDomain(ift.UnstructuredDomain(3))
    op = _complex2real(dom)
    # complex -> real -> complex must reproduce the input exactly, and the
    # intermediate representation must be real.
    cplx = ift.from_random(op.domain, 'normal', dtype=np.complex128)
    assert np.all((cplx == op.adjoint_times(op(cplx))).val)
    assert op(cplx).dtype == np.float64
    # real -> complex -> real must reproduce the input exactly as well.
    real = ift.from_random(op.target, 'normal')
    assert np.all((real == op(op.adjoint_times(real))).val)
Philipp Arras's avatar
Philipp Arras committed
62 63


64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
def energy_tester(pos, get_noisy_data, energy_initializer):
    """Check an energy's metric against the Fisher metric in real coordinates.

    The position is first mapped to a real representation. Complex (sub)fields
    are split into real and imaginary parts via ``_complex2real``, and for
    MultiFields all components are combined and passed through ``_DomRemover``.
    The statistical comparison itself is delegated to
    ``_actual_energy_tester``.

    Parameters
    ----------
    pos : Field or MultiField
        Position at which gradients and metric are evaluated.
    get_noisy_data : callable
        Maps a position in the original domain to one noisy data realization.
    energy_initializer : callable
        Maps a data realization to an energy operator on the original domain.
    """
    if isinstance(pos, ift.Field):
        if np.issubdtype(pos.dtype, np.complexfloating):
            op = _complex2real(pos.domain)
        else:
            op = ift.ScalingOperator(pos.domain, 1.)
    else:
        ops = []
        for key, dom in pos.domain.items():
            if np.issubdtype(pos[key].dtype, np.complexfloating):
                ops.append(_complex2real(dom).ducktape(key).ducktape_left(key))
            else:
                adapter = ift.FieldAdapter(dom, key)
                ops.append(adapter.adjoint @ adapter)
        realizer = ift.utilities.my_sum(ops)
        # Private NIFTy helper; presumably flattens the MultiDomain into one
        # plain domain so the tester can work with a single Field.
        from nifty7.operator_spectrum import _DomRemover
        op = _DomRemover(realizer.target) @ realizer

    # op.adjoint is used as the way back to the original domain (for
    # _complex2real, test_complex2real checks that it inverts op). This lets
    # the user callbacks keep working in the original domain while the tester
    # only sees real coordinates.
    def real_noisy_data(mean):
        return get_noisy_data(op.adjoint_times(mean))

    def real_energy(data):
        return energy_initializer(data) @ op.adjoint

    _actual_energy_tester(op(pos), real_noisy_data, real_energy)
87

Philipp Arras's avatar
Philipp Arras committed
88

89
def _actual_energy_tester(pos, get_noisy_data, energy_initializer):
    """Compare an energy's metric with a Monte-Carlo estimate of the Fisher metric.

    For ``Nsamp`` noisy data realizations drawn at ``pos``, the energy gradient
    ``g`` is computed. The sample mean of ``g * <g, v>`` for a random test
    vector ``v`` is then compared with ``metric(v)``. Each component must agree
    within 6 standard errors of the sample mean.

    Parameters
    ----------
    pos : Field
        Real-valued position at which gradients and metric are evaluated.
    get_noisy_data : callable
        Returns one noisy data realization for a given position.
    energy_initializer : callable
        Returns the energy operator for a given data realization.
    """
    test_vec = ift.from_random(pos.domain, 'normal')
    lin = ift.Linearization.make_var(pos)
    samples = []
    for _ in range(Nsamp):
        data = get_noisy_data(pos)
        energy = energy_initializer(data)
        grad = energy(lin).jac.adjoint(ift.full(energy.target, 1.))
        samples.append(_to_array((grad*grad.s_vdot(test_vec)).val))
    # Stack once and derive both statistics from the same array.
    samples = np.array(samples)
    mean = np.mean(samples, axis=0)
    stderr = np.std(samples, axis=0)/np.sqrt(Nsamp)

    # The last data realization is reused here; presumably the metric does
    # not depend on the data, so no fresh draw is needed.
    energy = energy_initializer(data)
    lin = ift.Linearization.make_var(pos, want_metric=True)
    metric_vec = _to_array(energy(lin).metric(test_vec).val)
    np.testing.assert_allclose(mean/stderr, metric_vec/stderr, atol=6)
105

Philipp Arras's avatar
Philipp Arras committed
106

107 108 109 110
def test_GaussianEnergy(field):
    """Fisher-metric test for GaussianEnergy.

    The inverse covariance is built with ``ift.makeOp`` from a squared random
    field.
    """
    dtype = field.dtype
    icov = ift.makeOp(ift.from_random(field.domain, 'normal')**2)

    def get_noisy_data(mean):
        # Noise drawn from the inverse of icov, with the dtype of the position.
        return mean + icov.draw_sample_with_dtype(from_inverse=True,
                                                  dtype=dtype)

    def E_init(data):
        return ift.GaussianEnergy(mean=data, inverse_covariance=icov)

    energy_tester(field, get_noisy_data, E_init)

Philipp Arras's avatar
Philipp Arras committed
116

117 118
def test_PoissonEnergy(field):
    """Fisher-metric test for PoissonianEnergy.

    Only real-valued Field positions are tested; other cases are skipped.
    """
    if not isinstance(field, ift.Field):
        pytest.skip("MultiField Poisson energy  not supported")
    if np.iscomplexobj(field.val):
        pytest.skip("Poisson energy not defined for complex flux")

    def get_noisy_data(mean):
        # Draws from numpy's global RNG, which is seeded at module level.
        return ift.makeField(mean.domain, np.random.poisson(mean.val))

    # Make rate positive and high enough to avoid bad statistic
    lam = 10*(field**2).clip(0.1, None)

    def E_init(data):
        return ift.PoissonianEnergy(data)

    energy_tester(lam, get_noisy_data, E_init)
127 128 129 130 131 132 133 134 135 136 137 138 139 140 141

def test_VariableCovarianceGaussianEnergy(dtype):
    """Fisher-metric test for VariableCovarianceGaussianEnergy.

    The position is a MultiField with a residual ``'res'`` of the requested
    dtype and an inverse variance ``'ivar'`` that is bounded below by 4.
    """
    dom = ift.UnstructuredDomain(3)
    res = ift.from_random(dom, 'normal', dtype=dtype)
    ivar = ift.from_random(dom, 'normal')**2 + 4.
    mf = ift.MultiField.from_dict({'res': res, 'ivar': ivar})
    energy = ift.VariableCovarianceGaussianEnergy(dom, 'res', 'ivar', dtype)

    def get_noisy_data(mean):
        # Normal noise scaled by 1/sqrt(ivar), added to the residual.
        noise = ift.from_random(dom, 'normal', dtype)
        noise = noise/mean['ivar'].sqrt()
        return noise + mean['res']

    def E_init(data):
        # Adder with neg=True presumably subtracts the data from 'res'.
        shift = ift.Adder(ift.MultiField.from_dict({'res': data}), neg=True)
        return energy.partial_insert(shift)

    energy_tester(mf, get_noisy_data, E_init)