# test_linearization.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np
import pytest
from numpy.testing import assert_, assert_allclose

import nifty6 as ift

# Not referenced directly: pytest picks up module-level ``setup_function`` /
# ``teardown_function`` and runs them around each test function in this module.
from .common import setup_function, teardown_function  # noqa: F401

# Short alias for parametrizing the tests in this module.
pmp = pytest.mark.parametrize


def _lin2grad(lin):
    """Return the gradient carried by the Jacobian of ``lin`` as an array.

    Applies the Jacobian of the linearization ``lin`` to a field that is one
    everywhere on ``lin.domain`` and returns the raw values (``.val``) of the
    result. The tests in this module compare this against the expected
    elementwise derivative of a pointwise operation.
    """
    return lin.jac(ift.full(lin.domain, 1.)).val


def jt(lin, check):
    """Assert that the gradient of ``lin`` matches ``check`` elementwise.

    The gradient is extracted with ``_lin2grad`` and compared via
    ``numpy.testing.assert_allclose`` using its default tolerances.
    """
    grad = _lin2grad(lin)
    assert_allclose(grad, check)


def test_special_gradients():
    """Check gradients of pointwise operations at special points.

    Covers ``clip`` with the expansion point inside and outside the clipping
    interval, ``sinc`` at zero, and ``abs`` at zero (expected NaN) as well as
    on either side of zero.
    """
    dom = ift.UnstructuredDomain((1,))
    f = ift.full(dom, 2.4)
    var = ift.Linearization.make_var(f)
    s = f.val

    # 2.4 lies inside [0, 10]: clip is the identity there, gradient is one.
    jt(var.clip(0, 10), np.ones_like(s))
    # 2.4 lies above [-1, 0]: the clipped value is constant, gradient is zero.
    jt(var.clip(-1, 0), np.zeros_like(s))

    assert_allclose(
        _lin2grad(ift.Linearization.make_var(0*f).ptw("sinc")),
        np.zeros(s.shape))
    # abs has no derivative at 0; the gradient is expected to be NaN there.
    # np.all keeps the check well-defined if the domain has more than one
    # element (a bare multi-element array has an ambiguous truth value).
    assert_(np.all(np.isnan(
        _lin2grad(ift.Linearization.make_var(0*f).ptw("abs")))))
    assert_allclose(
        _lin2grad(ift.Linearization.make_var(0*f + 10).ptw("abs")),
        np.ones(s.shape))
    assert_allclose(
        _lin2grad(ift.Linearization.make_var(0*f - 10).ptw("abs")),
        -np.ones(s.shape))


@pmp('f', [
    'log', 'exp', 'sqrt', 'sin', 'cos', 'tan', 'sinc', 'sinh', 'cosh', 'tanh',
    'absolute', 'reciprocal', 'sigmoid', 'log10', 'log1p', 'expm1'
])
def test_actual_gradients(f):
    """Compare the Jacobian gradient of ``ptw(f)`` with a finite difference.

    The derivative at 2.4 is approximated by a forward difference with step
    ``eps``. It must agree with the gradient from ``_lin2grad`` to a
    relative tolerance of ``100*eps``.
    """
    dom = ift.UnstructuredDomain((1,))
    fld = ift.full(dom, 2.4)
    eps = 1e-8
    var0 = ift.Linearization.make_var(fld)
    var1 = ift.Linearization.make_var(fld + eps)
    # Function values at x and x + eps as plain arrays.
    f0 = var0.ptw(f).val.val
    f1 = var1.ptw(f).val.val
    df0 = (f1 - f0)/eps
    df1 = _lin2grad(var0.ptw(f))
    assert_allclose(df0, df1, rtol=100*eps)