energy_and_model_tests.py 3.14 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

19
from __future__ import absolute_import, division, print_function
Philipp Arras's avatar
Philipp Arras committed
20

21
import numpy as np
Philipp Arras's avatar
Philipp Arras committed
22
23

from ..compat import *
Martin Reinecke's avatar
Martin Reinecke committed
24
from ..linearization import Linearization
Philipp Arras's avatar
Philipp Arras committed
25
from ..sugar import from_random
26

Martin Reinecke's avatar
Martin Reinecke committed
27
__all__ = ["check_value_gradient_consistency",
Martin Reinecke's avatar
Martin Reinecke committed
28
           "check_value_gradient_metric_consistency"]
29

Philipp Arras's avatar
Philipp Arras committed
30

Martin Reinecke's avatar
Martin Reinecke committed
31
def _get_acceptable_location(op, loc, lin):
    """Find a nearby location at which `op` yields a usable value.

    A random direction is drawn on ``loc.domain`` and rescaled using the
    norm of ``lin.val`` and, when it is nonzero, the norm of the
    directional derivative ``lin.jac(direction)``.  The step is then halved
    until the operator value at ``loc + step`` has a finite sum whose
    magnitude is below 1e20.

    Parameters
    ----------
    op : callable
        Operator under test; it is called with a `Linearization`.
    loc
        Starting location.
    lin
        Result of applying `op` to the linearization at `loc`.

    Returns
    -------
    tuple
        ``(loc2, lin2)``: the accepted location and `op` evaluated there.

    Raises
    ------
    ValueError
        If the sum of ``lin.val`` is not finite, or if no acceptable step
        is found within 50 halvings.
    """
    if not np.isfinite(lin.val.sum()):
        raise ValueError('Initial value must be finite')

    step = from_random("normal", loc.domain)
    jac_norm = lin.jac(step).norm()
    # Scale relative to the size of the value; additionally normalise by
    # the directional derivative whenever that is nonzero.
    scale = lin.val.norm()*1e-5
    if jac_norm != 0:
        scale = scale/jac_norm
    step = step * scale

    # Find a step length that leads to a "reasonable" location
    for _ in range(50):
        try:
            candidate = loc+step
            lin_candidate = op(
                Linearization.make_var(candidate, lin.want_metric))
            total = lin_candidate.val.sum()
            if np.isfinite(total) and abs(total) < 1e20:
                return candidate, lin_candidate
        except FloatingPointError:
            # Treat a floating point failure like an unusable value and
            # retry with a shorter step.
            pass
        step = step*0.5
    raise ValueError("could not find a reasonable initial step")

Martin Reinecke's avatar
Martin Reinecke committed
54

Martin Reinecke's avatar
Martin Reinecke committed
55
def _check_consistency(op, loc, tol, ntries, do_metric):
    """Compare the derivatives reported by `op` with finite differences.

    For each of `ntries` trials a nearby location is obtained from
    `_get_acceptable_location`.  The Jacobian evaluated at the midpoint of
    the step, applied to the step, is compared element-wise with the
    difference of the operator values at the two end points.  If
    `do_metric` is true, the metric at the midpoint applied to the step is
    additionally compared with the difference of the gradients at the end
    points.  Whenever the comparison fails the step is halved (the former
    midpoint becomes the new end point) and the check is repeated, up to
    50 times.

    Parameters
    ----------
    op : callable
        Operator under test; it is called with a `Linearization`.
    loc
        Location at which the first trial starts.
    tol : float
        Tolerance factor; the element-wise threshold is
        ``tol * dirder.norm() / sqrt(dirder.size)``.
    ntries : int
        Number of trials.  Each trial starts where the previous one's
        initial step ended.
    do_metric : bool
        Whether to request the metric and check it as well.

    Raises
    ------
    ValueError
        If the comparison still fails after 50 halvings of the step, or if
        `_get_acceptable_location` raises.
    """
    for _ in range(ntries):
        lin = op(Linearization.make_var(loc, do_metric))
        loc2, lin2 = _get_acceptable_location(op, loc, lin)
        step = loc2-loc
        # The next trial continues from the far end of the *initial* step,
        # regardless of how often the step is halved below.
        locnext = loc2
        for _halving in range(50):
            locmid = loc + 0.5*step
            linmid = op(Linearization.make_var(locmid, do_metric))
            dirder = linmid.jac(step)
            numgrad = lin2.val-lin.val
            xtol = tol * dirder.norm() / np.sqrt(dirder.size)
            cond = (abs(numgrad-dirder) <= xtol).all()
            if do_metric:
                # The metric check reuses the threshold derived from the
                # Jacobian term above.
                dgrad = linmid.metric(step)
                dgrad2 = lin2.gradient-lin.gradient
                cond = cond and (abs(dgrad-dgrad2) <= xtol).all()
            if cond:
                break
            # Shrink the interval towards `loc`: the old midpoint becomes
            # the new end point.
            step = step*0.5
            loc2, lin2 = locmid, linmid
        else:
            raise ValueError("gradient and value seem inconsistent")
        loc = locnext
Martin Reinecke's avatar
Martin Reinecke committed
81
82


Martin Reinecke's avatar
Martin Reinecke committed
83
84
85
86
def check_value_gradient_consistency(op, loc, tol=1e-8, ntries=100):
    """Check that the Jacobian of `op` agrees with finite differences.

    Delegates to `_check_consistency` with the metric check disabled.

    Parameters
    ----------
    op : callable
        Operator under test; it is called with a `Linearization`.
    loc
        Location at which the first trial starts.
    tol : float
        Tolerance factor for the element-wise comparison.
    ntries : int
        Number of trials to perform.

    Raises
    ------
    ValueError
        If value and Jacobian are found to be inconsistent.
    """
    _check_consistency(op, loc, tol, ntries, do_metric=False)


Martin Reinecke's avatar
Martin Reinecke committed
87
def check_value_gradient_metric_consistency(op, loc, tol=1e-8, ntries=100):
    """Check Jacobian and metric of `op` against finite differences.

    Delegates to `_check_consistency` with the metric check enabled, so in
    addition to the value/Jacobian comparison the metric is compared with
    differences of the gradient.

    Parameters
    ----------
    op : callable
        Operator under test; it is called with a `Linearization`.
    loc
        Location at which the first trial starts.
    tol : float
        Tolerance factor for the element-wise comparison.
    ntries : int
        Number of trials to perform.

    Raises
    ------
    ValueError
        If the derivatives are found to be inconsistent.
    """
    _check_consistency(op, loc, tol, ntries, do_metric=True)