Commit 5dbc6d90 authored by Philipp Arras's avatar Philipp Arras Committed by Reimar H Leike

Cosmetics

parent e5803345
......@@ -18,12 +18,12 @@
import numpy as np
from ..sugar import from_random
from .operator_tests import _assert_allclose
from .. import Energy, Model
__all__ = ["check_value_gradient_consistency",
"check_value_gradient_curvature_consistency"]
def _get_acceptable_model(M):
# TODO: do for model
val = M.value
......@@ -32,7 +32,7 @@ def _get_acceptable_model(M):
dir = from_random("normal", M.position.domain)
dirder = M.gradient(dir)
dir *= val/(dirder).norm()*1e-5
# find a step length that leads to a "reasonable" Model
# Find a step length that leads to a "reasonable" Model
for i in range(50):
try:
M2 = M.at(M.position+dir)
......@@ -53,7 +53,7 @@ def _get_acceptable_energy(E):
dir = from_random("normal", E.position.domain)
dirder = E.gradient.vdot(dir)
dir *= np.abs(val)/np.abs(dirder)*1e-5
# find a step length that leads to a "reasonable" energy
# Find a step length that leads to a "reasonable" energy
for i in range(50):
try:
E2 = E.at(E.position+dir)
......@@ -84,7 +84,7 @@ def check_value_gradient_consistency(E, tol=1e-8, ntries=100):
else:
dirder = Emid.gradient(dir)/dirnorm
if isinstance(E, Model):
#TODO: use relative error here instead
# TODO: use relative error here instead
if ((E2.value-val)/dirnorm - dirder).norm()/dirder.norm()< tol:
break
else:
......
......@@ -17,12 +17,11 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
import unittest
import nifty5 as ift
import numpy as np
from itertools import product
from test.common import expand
from numpy.testing import assert_allclose
import nifty5 as ift
import numpy as np
class Model_Tests(unittest.TestCase):
......@@ -32,9 +31,9 @@ class Model_Tests(unittest.TestCase):
[4, 78, 23]))
def testMul(self, space, seed):
np.random.seed(seed)
#TODO: use random Multifields instead
# TODO: use random Multifields instead
s1 = ift.full(space, np.random.randn())
s2 = ift.full(space, np.random.randn())
s1_var = ift.Variable(ift.MultiField({'s1':s1}))['s1']
s2_var = ift.Variable(ift.MultiField({'s2':s2}))['s2']
s1_var = ift.Variable(ift.MultiField({'s1': s1}))['s1']
s2_var = ift.Variable(ift.MultiField({'s2': s2}))['s2']
ift.extra.check_value_gradient_consistency(s1_var*s2_var)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment