Commit 743e98c9 authored by Philipp Arras's avatar Philipp Arras
Browse files

MultiField tests for energies

parent 2c338fb9
Pipeline #70463 passed with stages
in 16 minutes and 57 seconds
......@@ -28,26 +28,18 @@ from itertools import product
SPACES = [ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)]
for sp in SPACES[:3]:
SPACES.append(ift.MultiDomain.make({'asdf': sp}))
SEEDS = [4, 78, 23]
PARAMS = product(SEEDS, SPACES)
pmp = pytest.mark.parametrize
# FIXME Test also with multifields in domain
@pytest.fixture(params=PARAMS)
def field(request):
np.random.seed(request.param[0])
S = ift.ScalingOperator(request.param[1], 1.)
s = S.draw_sample()
return ift.MultiField.from_dict({'s1': s})['s1']
def test_variablecovariancegaussian(field):
dc = {'a': field, 'b': field.exp()}
mf = ift.MultiField.from_dict(dc)
energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b')
ift.extra.check_jacobian_consistency(energy, mf, tol=1e-6)
energy(ift.Linearization.make_var(mf, want_metric=True)).metric.draw_sample()
return S.draw_sample()
def test_gaussian(field):
......@@ -55,26 +47,52 @@ def test_gaussian(field):
ift.extra.check_jacobian_consistency(energy, field)
@pmp('icov', [lambda dom:ift.ScalingOperator(dom, 1.),
lambda dom:ift.SandwichOperator.make(ift.GeometryRemover(dom))])
def test_ScaledEnergy(field, icov):
icov = icov(field.domain)
def test_ScaledEnergy(field):
icov = ift.ScalingOperator(field.domain, 1.2)
energy = ift.GaussianEnergy(inverse_covariance=icov)
ift.extra.check_jacobian_consistency(energy.scale(0.3), field)
lin = ift.Linearization.make_var(field, want_metric=True)
met1 = energy(lin).metric
met2 = energy.scale(0.3)(lin).metric
np.testing.assert_allclose(met1(field).val, met2(field).val/0.3, rtol=1e-12)
res1 = met1(field)
res2 = met2(field)/0.3
ift.extra.assert_allclose(res1, res2, 0, 1e-12)
met2.draw_sample()
def test_studentt(field):
if isinstance(field.domain, ift.MultiDomain):
return
energy = ift.StudentTEnergy(domain=field.domain, theta=.5)
ift.extra.check_jacobian_consistency(energy, field, tol=1e-6)
def test_hamiltonian_and_KL(field):
field = field.exp()
space = field.domain
lh = ift.GaussianEnergy(domain=space)
hamiltonian = ift.StandardHamiltonian(lh)
ift.extra.check_jacobian_consistency(hamiltonian, field)
S = ift.ScalingOperator(space, 1.)
samps = [S.draw_sample() for i in range(3)]
kl = ift.AveragedEnergy(hamiltonian, samps)
ift.extra.check_jacobian_consistency(kl, field)
def test_variablecovariancegaussian(field):
if isinstance(field.domain, ift.MultiDomain):
return
dc = {'a': field, 'b': field.exp()}
mf = ift.MultiField.from_dict(dc)
energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b')
ift.extra.check_jacobian_consistency(energy, mf, tol=1e-6)
energy(ift.Linearization.make_var(mf, want_metric=True)).metric.draw_sample()
def test_inverse_gamma(field):
if isinstance(field.domain, ift.MultiDomain):
return
field = field.exp()
space = field.domain
d = np.random.normal(10, size=space.shape)**2
......@@ -84,6 +102,8 @@ def test_inverse_gamma(field):
def testPoissonian(field):
if isinstance(field.domain, ift.MultiDomain):
return
field = field.exp()
space = field.domain
d = np.random.poisson(120, size=space.shape)
......@@ -92,19 +112,9 @@ def testPoissonian(field):
ift.extra.check_jacobian_consistency(energy, field, tol=1e-7)
def test_hamiltonian_and_KL(field):
field = field.exp()
space = field.domain
lh = ift.GaussianEnergy(domain=space)
hamiltonian = ift.StandardHamiltonian(lh)
ift.extra.check_jacobian_consistency(hamiltonian, field)
S = ift.ScalingOperator(space, 1.)
samps = [S.draw_sample() for i in range(3)]
kl = ift.AveragedEnergy(hamiltonian, samps)
ift.extra.check_jacobian_consistency(kl, field)
def test_bernoulli(field):
if isinstance(field.domain, ift.MultiDomain):
return
field = field.sigmoid()
space = field.domain
d = np.random.binomial(1, 0.1, size=space.shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment