Skip to content
Snippets Groups Projects
Commit 743e98c9 authored by Philipp Arras's avatar Philipp Arras
Browse files

MultiField tests for energies

parent 2c338fb9
No related branches found
No related tags found
1 merge request!417Performance pa
Pipeline #70463 passed
......@@ -28,26 +28,18 @@ from itertools import product
SPACES = [ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)]
for sp in SPACES[:3]:
SPACES.append(ift.MultiDomain.make({'asdf': sp}))
SEEDS = [4, 78, 23]
PARAMS = product(SEEDS, SPACES)
pmp = pytest.mark.parametrize
# FIXME Test also with multifields in domain
@pytest.fixture(params=PARAMS)
def field(request):
np.random.seed(request.param[0])
S = ift.ScalingOperator(request.param[1], 1.)
s = S.draw_sample()
return ift.MultiField.from_dict({'s1': s})['s1']
def test_variablecovariancegaussian(field):
dc = {'a': field, 'b': field.exp()}
mf = ift.MultiField.from_dict(dc)
energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b')
ift.extra.check_jacobian_consistency(energy, mf, tol=1e-6)
energy(ift.Linearization.make_var(mf, want_metric=True)).metric.draw_sample()
return S.draw_sample()
def test_gaussian(field):
......@@ -55,26 +47,52 @@ def test_gaussian(field):
ift.extra.check_jacobian_consistency(energy, field)
@pmp('icov', [lambda dom:ift.ScalingOperator(dom, 1.),
lambda dom:ift.SandwichOperator.make(ift.GeometryRemover(dom))])
def test_ScaledEnergy(field, icov):
icov = icov(field.domain)
def test_ScaledEnergy(field):
icov = ift.ScalingOperator(field.domain, 1.2)
energy = ift.GaussianEnergy(inverse_covariance=icov)
ift.extra.check_jacobian_consistency(energy.scale(0.3), field)
lin = ift.Linearization.make_var(field, want_metric=True)
met1 = energy(lin).metric
met2 = energy.scale(0.3)(lin).metric
np.testing.assert_allclose(met1(field).val, met2(field).val/0.3, rtol=1e-12)
res1 = met1(field)
res2 = met2(field)/0.3
ift.extra.assert_allclose(res1, res2, 0, 1e-12)
met2.draw_sample()
def test_studentt(field):
if isinstance(field.domain, ift.MultiDomain):
return
energy = ift.StudentTEnergy(domain=field.domain, theta=.5)
ift.extra.check_jacobian_consistency(energy, field, tol=1e-6)
def test_hamiltonian_and_KL(field):
field = field.exp()
space = field.domain
lh = ift.GaussianEnergy(domain=space)
hamiltonian = ift.StandardHamiltonian(lh)
ift.extra.check_jacobian_consistency(hamiltonian, field)
S = ift.ScalingOperator(space, 1.)
samps = [S.draw_sample() for i in range(3)]
kl = ift.AveragedEnergy(hamiltonian, samps)
ift.extra.check_jacobian_consistency(kl, field)
def test_variablecovariancegaussian(field):
if isinstance(field.domain, ift.MultiDomain):
return
dc = {'a': field, 'b': field.exp()}
mf = ift.MultiField.from_dict(dc)
energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b')
ift.extra.check_jacobian_consistency(energy, mf, tol=1e-6)
energy(ift.Linearization.make_var(mf, want_metric=True)).metric.draw_sample()
def test_inverse_gamma(field):
if isinstance(field.domain, ift.MultiDomain):
return
field = field.exp()
space = field.domain
d = np.random.normal(10, size=space.shape)**2
......@@ -84,6 +102,8 @@ def test_inverse_gamma(field):
def testPoissonian(field):
if isinstance(field.domain, ift.MultiDomain):
return
field = field.exp()
space = field.domain
d = np.random.poisson(120, size=space.shape)
......@@ -92,19 +112,9 @@ def testPoissonian(field):
ift.extra.check_jacobian_consistency(energy, field, tol=1e-7)
def test_hamiltonian_and_KL(field):
field = field.exp()
space = field.domain
lh = ift.GaussianEnergy(domain=space)
hamiltonian = ift.StandardHamiltonian(lh)
ift.extra.check_jacobian_consistency(hamiltonian, field)
S = ift.ScalingOperator(space, 1.)
samps = [S.draw_sample() for i in range(3)]
kl = ift.AveragedEnergy(hamiltonian, samps)
ift.extra.check_jacobian_consistency(kl, field)
def test_bernoulli(field):
if isinstance(field.domain, ift.MultiDomain):
return
field = field.sigmoid()
space = field.domain
d = np.random.binomial(1, 0.1, size=space.shape)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment