diff --git a/test/test_energy_gradients.py b/test/test_energy_gradients.py index 117fc85c8d84d5aca9c0bb5df2b12e1e8127f761..32ce00e24601d2f663a2c0a56ab89bd8d71e1294 100644 --- a/test/test_energy_gradients.py +++ b/test/test_energy_gradients.py @@ -15,28 +15,19 @@ # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. -from itertools import product - import numpy as np import pytest import nifty6 as ift -from .common import setup_function, teardown_function - -# Currently it is not possible to parametrize fixtures. But this will -# hopefully be fixed in the future. -# https://docs.pytest.org/en/latest/proposals/parametrize_with_fixtures.html +from .common import list2fixture, setup_function, teardown_function -spaces = [ift.GLSpace(15), - ift.MultiDomain.make({'': ift.RGSpace(64, distances=.789)}), - (ift.RGSpace(4, distances=.789), ift.UnstructuredDomain(3))] +spaces = [ift.GLSpace(5), + ift.MultiDomain.make({'': ift.RGSpace(5, distances=.789)}), + (ift.RGSpace(3, distances=.789), ift.UnstructuredDomain(2))] pmp = pytest.mark.parametrize - - -@pytest.fixture(params=[spaces]) -def field(request): - return ift.from_random('normal', request.param[0]) +field = list2fixture([ift.from_random('normal', sp) for sp in spaces]) +ntries = 10 def test_gaussian(field): @@ -73,7 +64,7 @@ def test_studentt(field): ift.extra.check_jacobian_consistency(energy, field, tol=1e-6) theta = ift.from_random('normal',field.domain).exp() energy = ift.StudentTEnergy(domain=field.domain, theta=theta) - ift.extra.check_jacobian_consistency(energy, field, tol=1e-6) + ift.extra.check_jacobian_consistency(energy, field, tol=1e-6, ntries=ntries) def test_hamiltonian_and_KL(field): @@ -81,10 +72,10 @@ def test_hamiltonian_and_KL(field): space = field.domain lh = ift.GaussianEnergy(domain=space) hamiltonian = ift.StandardHamiltonian(lh) - ift.extra.check_jacobian_consistency(hamiltonian, field) - samps = [ift.from_random('normal', space) for i in range(3)] + ift.extra.check_jacobian_consistency(hamiltonian, field, ntries=ntries) + samps = [ift.from_random('normal', space) for i in range(2)] kl = ift.AveragedEnergy(hamiltonian, samps) - ift.extra.check_jacobian_consistency(kl, field) + ift.extra.check_jacobian_consistency(kl, field, ntries=ntries) def test_variablecovariancegaussian(field): @@ -93,7 +84,7 @@ def test_variablecovariancegaussian(field): dc = {'a': field, 'b': field.ptw("exp")} mf = ift.MultiField.from_dict(dc) energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b', np.float64) - ift.extra.check_jacobian_consistency(energy, mf, tol=1e-6) + ift.extra.check_jacobian_consistency(energy, mf, tol=1e-6, ntries=ntries) energy(ift.Linearization.make_var(mf, want_metric=True)).metric.draw_sample()