Skip to content
Snippets Groups Projects
Commit f45e9a35 authored by Philipp Arras's avatar Philipp Arras
Browse files

Fixups

parent 92c36ed9
No related branches found
No related tags found
1 merge request!471Speedup tests
Pipeline #75157 passed
...@@ -15,28 +15,19 @@ ...@@ -15,28 +15,19 @@
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from itertools import product
import numpy as np import numpy as np
import pytest import pytest
import nifty6 as ift import nifty6 as ift
from .common import setup_function, teardown_function from .common import list2fixture, setup_function, teardown_function
# Currently it is not possible to parametrize fixtures. But this will
# hopefully be fixed in the future.
# https://docs.pytest.org/en/latest/proposals/parametrize_with_fixtures.html
spaces = [ift.GLSpace(15), spaces = [ift.GLSpace(5),
ift.MultiDomain.make({'': ift.RGSpace(64, distances=.789)}), ift.MultiDomain.make({'': ift.RGSpace(5, distances=.789)}),
(ift.RGSpace(4, distances=.789), ift.UnstructuredDomain(3))] (ift.RGSpace(3, distances=.789), ift.UnstructuredDomain(2))]
pmp = pytest.mark.parametrize pmp = pytest.mark.parametrize
field = list2fixture([ift.from_random('normal', sp) for sp in spaces])
ntries = 10
@pytest.fixture(params=[spaces])
def field(request):
return ift.from_random('normal', request.param[0])
def test_gaussian(field): def test_gaussian(field):
...@@ -73,7 +64,7 @@ def test_studentt(field): ...@@ -73,7 +64,7 @@ def test_studentt(field):
ift.extra.check_jacobian_consistency(energy, field, tol=1e-6) ift.extra.check_jacobian_consistency(energy, field, tol=1e-6)
theta = ift.from_random('normal',field.domain).exp() theta = ift.from_random('normal',field.domain).exp()
energy = ift.StudentTEnergy(domain=field.domain, theta=theta) energy = ift.StudentTEnergy(domain=field.domain, theta=theta)
ift.extra.check_jacobian_consistency(energy, field, tol=1e-6) ift.extra.check_jacobian_consistency(energy, field, tol=1e-6, ntries=ntries)
def test_hamiltonian_and_KL(field): def test_hamiltonian_and_KL(field):
...@@ -81,10 +72,10 @@ def test_hamiltonian_and_KL(field): ...@@ -81,10 +72,10 @@ def test_hamiltonian_and_KL(field):
space = field.domain space = field.domain
lh = ift.GaussianEnergy(domain=space) lh = ift.GaussianEnergy(domain=space)
hamiltonian = ift.StandardHamiltonian(lh) hamiltonian = ift.StandardHamiltonian(lh)
ift.extra.check_jacobian_consistency(hamiltonian, field) ift.extra.check_jacobian_consistency(hamiltonian, field, ntries=ntries)
samps = [ift.from_random('normal', space) for i in range(3)] samps = [ift.from_random('normal', space) for i in range(2)]
kl = ift.AveragedEnergy(hamiltonian, samps) kl = ift.AveragedEnergy(hamiltonian, samps)
ift.extra.check_jacobian_consistency(kl, field) ift.extra.check_jacobian_consistency(kl, field, ntries=ntries)
def test_variablecovariancegaussian(field): def test_variablecovariancegaussian(field):
...@@ -93,7 +84,7 @@ def test_variablecovariancegaussian(field): ...@@ -93,7 +84,7 @@ def test_variablecovariancegaussian(field):
dc = {'a': field, 'b': field.ptw("exp")} dc = {'a': field, 'b': field.ptw("exp")}
mf = ift.MultiField.from_dict(dc) mf = ift.MultiField.from_dict(dc)
energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b', np.float64) energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b', np.float64)
ift.extra.check_jacobian_consistency(energy, mf, tol=1e-6) ift.extra.check_jacobian_consistency(energy, mf, tol=1e-6, ntries=ntries)
energy(ift.Linearization.make_var(mf, want_metric=True)).metric.draw_sample() energy(ift.Linearization.make_var(mf, want_metric=True)).metric.draw_sample()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment