Commit 8087f286 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'improvetesting' into 'NIFTy_7'

Improvetesting

See merge request !559
parents 64a02e2c b8da1c47
Pipeline #78582 passed with stages
in 26 minutes and 11 seconds
......@@ -15,13 +15,12 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from unittest import SkipTest
import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_equal
import nifty7 as ift
from .common import setup_function, teardown_function
pmp = pytest.mark.parametrize
......@@ -64,7 +63,7 @@ def test_quadratic_minimization(minimizer, space):
(energy, convergence) = minimizer(energy)
except NotImplementedError:
raise SkipTest
pytest.skip()
assert_equal(convergence, IC.CONVERGED)
assert_allclose(
......@@ -123,7 +122,7 @@ def test_rosenbrock(minimizer):
try:
from scipy.optimize import rosen, rosen_der, rosen_hess_prod
except ImportError:
raise SkipTest
pytest.skip()
space = ift.DomainTuple.make(ift.UnstructuredDomain((2,)))
starting_point = ift.Field.from_random(domain=space, random_type='normal') * 10
......@@ -171,7 +170,7 @@ def test_rosenbrock(minimizer):
(energy, convergence) = minimizer(energy)
except NotImplementedError:
raise SkipTest
pytest.skip()
assert_equal(convergence, IC.CONVERGED)
assert_allclose(energy.position.val, 1., rtol=1e-3, atol=1e-3)
......@@ -208,7 +207,7 @@ def test_gauss(minimizer):
(energy, convergence) = minimizer(energy)
except NotImplementedError:
raise SkipTest
pytest.skip()
assert_equal(convergence, IC.CONVERGED)
assert_allclose(energy.position.val, 0., atol=1e-3)
......@@ -252,7 +251,7 @@ def test_cosh(minimizer):
(energy, convergence) = minimizer(energy)
except NotImplementedError:
raise SkipTest
pytest.skip()
assert_equal(convergence, IC.CONVERGED)
assert_allclose(energy.position.val, 0., atol=1e-3)
......@@ -96,8 +96,6 @@ def _actual_energy_tester(pos, get_noisy_data, energy_initializer):
energy = energy_initializer(data)
grad = energy(lin).jac.adjoint(ift.full(energy.target, 1.))
results.append(_to_array((grad*grad.s_vdot(test_vec)).val))
print(energy)
print(grad)
res = np.mean(np.array(results), axis=0)
std = np.std(np.array(results), axis=0)/np.sqrt(Nsamp)
energy = energy_initializer(data)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment