Skip to content
Snippets Groups Projects
Commit f08e8ab5 authored by Reimar H Leike's avatar Reimar H Leike
Browse files

deleted metrixc test as it was not doing what it said

parent 94e8fc43
No related branches found
No related tags found
No related merge requests found
......@@ -134,24 +134,20 @@ def _get_acceptable_location(op, loc, lin):
return loc2, lin2
def _check_consistency(op, loc, tol, ntries, do_metric):
def _check_consistency(op, loc, tol, ntries):
for _ in range(ntries):
lin = op(Linearization.make_var(loc, do_metric))
lin = op(Linearization.make_var(loc))
loc2, lin2 = _get_acceptable_location(op, loc, lin)
dir = loc2-loc
locnext = loc2
dirnorm = dir.norm()
for i in range(50):
locmid = loc + 0.5*dir
linmid = op(Linearization.make_var(locmid, do_metric))
linmid = op(Linearization.make_var(locmid))
dirder = linmid.jac(dir)
numgrad = (lin2.val-lin.val)
xtol = tol * dirder.norm() / np.sqrt(dirder.size)
cond = (abs(numgrad-dirder) <= xtol).all()
if do_metric:
dgrad = linmid.metric(dir)
dgrad2 = (lin2.gradient-lin.gradient)
cond = cond and (abs(dgrad-dgrad2) <= xtol).all()
if cond:
break
dir = dir*0.5
......@@ -185,9 +181,5 @@ def check_value_gradient_consistency(op, loc, tol=1e-8, ntries=100):
then satisfying any tolerance will let the check pass.
Default: 0
"""
_check_consistency(op, loc, tol, ntries, False)
_check_consistency(op, loc, tol, ntries)
def check_value_gradient_metric_consistency(op, loc, tol=1e-8, ntries=100):
"""FIXME"""
_check_consistency(op, loc, tol, ntries, True)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
import pytest
import nifty5 as ift
def _flat_PS(k):
return np.ones_like(k)
pmp = pytest.mark.parametrize
@pmp('space', [
ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)
])
@pmp('nonlinearity', ["tanh", "exp", ""])
@pmp('noise', [1, 1e-2, 1e2])
@pmp('seed', [4, 78, 23])
def test_gaussian_energy(space, nonlinearity, noise, seed):
np.random.seed(seed)
dim = len(space.shape)
hspace = space.get_default_codomain()
ht = ift.HarmonicTransformOperator(hspace, target=space)
binbounds = ift.PowerSpace.useful_binbounds(hspace, logarithmic=False)
pspace = ift.PowerSpace(hspace, binbounds=binbounds)
Dist = ift.PowerDistributor(target=hspace, power_space=pspace)
xi0 = ift.Field.from_random(domain=hspace, random_type='normal')
def pspec(k):
return 1/(1 + k**2)**dim
pspec = ift.PS_field(pspace, pspec)
A = Dist(ift.sqrt(pspec))
N = ift.ScalingOperator(noise, space)
n = N.draw_sample()
R = ift.ScalingOperator(10., space)
def d_model():
if nonlinearity == "":
return R(ht(ift.makeOp(A)))
else:
tmp = ht(ift.makeOp(A))
nonlin = getattr(tmp, nonlinearity)()
return R(nonlin)
d = d_model()(xi0) + n
if noise == 1:
N = None
energy = ift.GaussianEnergy(d, N)(d_model())
ift.extra.check_value_gradient_consistency(
energy, xi0, ntries=10, tol=5e-8)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment