Commit 1d25e062 authored by Reimar H Leike's avatar Reimar H Leike
Browse files

value-gradient consistency check now also works with models, added a test that...

value-gradient consistency check now also works with models, added a test that uses this as a test for the test
parent ae802b86
......@@ -18,15 +18,38 @@
import numpy as np
from ..sugar import from_random
from .operator_tests import _assert_allclose
from .. import Energy, Model
__all__ = ["check_value_gradient_consistency",
def _get_acceptable_model(M):
# TODO: do for model
val = M.value
if not np.isfinite(val.sum()):
raise ValueError('Initial Model value must be finite')
dir = from_random("normal", M.position.domain)
dirder = M.gradient(dir)
dir *= val/(dirder).norm()*1e-5
# find a step length that leads to a "reasonable" Model
for i in range(50):
M2 =
if np.isfinite(M2.value.sum()) and abs(M2.value.sum()) < 1e20:
except FloatingPointError:
dir *= 0.5
raise ValueError("could not find a reasonable initial step")
return M2
def _get_acceptable_energy(E):
val = E.value
if not np.isfinite(val):
raise ValueError
raise ValueError('Initial Energy must be finite')
dir = from_random("normal", E.position.domain)
dirder = E.gradient.vdot(dir)
dir *= np.abs(val)/np.abs(dirder)*1e-5
......@@ -46,14 +69,25 @@ def _get_acceptable_energy(E):
def check_value_gradient_consistency(E, tol=1e-8, ntries=100):
for _ in range(ntries):
if isinstance(E, Energy):
E2 = _get_acceptable_energy(E)
E2 = _get_acceptable_model(E)
val = E.value
dir = E2.position - E.position
# Enext = E2
dirnorm = dir.norm()
for i in range(50):
Emid = + 0.5*dir)
if isinstance(E, Energy):
dirder = Emid.gradient.vdot(dir)/dirnorm
dirder = Emid.gradient(dir)/dirnorm
if isinstance(E, Model):
#TODO: use relative error here instead
if ((E2.value-val)/dirnorm - dirder).norm()/dirder.norm()< tol:
xtol = tol*Emid.gradient_norm
if abs((E2.value-val)/dirnorm - dirder) < xtol:
......@@ -66,6 +100,8 @@ def check_value_gradient_consistency(E, tol=1e-8, ntries=100):
def check_value_gradient_curvature_consistency(E, tol=1e-8, ntries=100):
if isinstance(E, Model):
raise ValueError('Models have no curvature, thus it cannot be tested.')
for _ in range(ntries):
E2 = _get_acceptable_energy(E)
val = E.value
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <>.
# Copyright(C) 2013-2018 Max-Planck-Society
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import unittest
import nifty5 as ift
import numpy as np
from itertools import product
from test.common import expand
from numpy.testing import assert_allclose
class Model_Tests(unittest.TestCase):
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]))
def testMul(self, space, seed):
#TODO: use random Multifields instead
s1 = ift.full(space, np.random.randn())
s2 = ift.full(space, np.random.randn())
s1_var = ift.Variable(ift.MultiField({'s1':s1}))['s1']
s2_var = ift.Variable(ift.MultiField({'s2':s2}))['s2']
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment