Skip to content
Snippets Groups Projects
Commit 743e98c9 authored by Philipp Arras's avatar Philipp Arras
Browse files

MultiField tests for energies

parent 2c338fb9
No related branches found
No related tags found
1 merge request!417Performance pa
Pipeline #70463 passed
...@@ -28,26 +28,18 @@ from itertools import product ...@@ -28,26 +28,18 @@ from itertools import product
SPACES = [ift.GLSpace(15), SPACES = [ift.GLSpace(15),
ift.RGSpace(64, distances=.789), ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)] ift.RGSpace([32, 32], distances=.789)]
for sp in SPACES[:3]:
SPACES.append(ift.MultiDomain.make({'asdf': sp}))
SEEDS = [4, 78, 23] SEEDS = [4, 78, 23]
PARAMS = product(SEEDS, SPACES) PARAMS = product(SEEDS, SPACES)
pmp = pytest.mark.parametrize pmp = pytest.mark.parametrize
# FIXME Test also with multifields in domain
@pytest.fixture(params=PARAMS) @pytest.fixture(params=PARAMS)
def field(request): def field(request):
np.random.seed(request.param[0]) np.random.seed(request.param[0])
S = ift.ScalingOperator(request.param[1], 1.) S = ift.ScalingOperator(request.param[1], 1.)
s = S.draw_sample() return S.draw_sample()
return ift.MultiField.from_dict({'s1': s})['s1']
def test_variablecovariancegaussian(field):
dc = {'a': field, 'b': field.exp()}
mf = ift.MultiField.from_dict(dc)
energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b')
ift.extra.check_jacobian_consistency(energy, mf, tol=1e-6)
energy(ift.Linearization.make_var(mf, want_metric=True)).metric.draw_sample()
def test_gaussian(field): def test_gaussian(field):
...@@ -55,26 +47,52 @@ def test_gaussian(field): ...@@ -55,26 +47,52 @@ def test_gaussian(field):
ift.extra.check_jacobian_consistency(energy, field) ift.extra.check_jacobian_consistency(energy, field)
@pmp('icov', [lambda dom:ift.ScalingOperator(dom, 1.), def test_ScaledEnergy(field):
lambda dom:ift.SandwichOperator.make(ift.GeometryRemover(dom))]) icov = ift.ScalingOperator(field.domain, 1.2)
def test_ScaledEnergy(field, icov):
icov = icov(field.domain)
energy = ift.GaussianEnergy(inverse_covariance=icov) energy = ift.GaussianEnergy(inverse_covariance=icov)
ift.extra.check_jacobian_consistency(energy.scale(0.3), field) ift.extra.check_jacobian_consistency(energy.scale(0.3), field)
lin = ift.Linearization.make_var(field, want_metric=True) lin = ift.Linearization.make_var(field, want_metric=True)
met1 = energy(lin).metric met1 = energy(lin).metric
met2 = energy.scale(0.3)(lin).metric met2 = energy.scale(0.3)(lin).metric
np.testing.assert_allclose(met1(field).val, met2(field).val/0.3, rtol=1e-12) res1 = met1(field)
res2 = met2(field)/0.3
ift.extra.assert_allclose(res1, res2, 0, 1e-12)
met2.draw_sample() met2.draw_sample()
def test_studentt(field): def test_studentt(field):
if isinstance(field.domain, ift.MultiDomain):
return
energy = ift.StudentTEnergy(domain=field.domain, theta=.5) energy = ift.StudentTEnergy(domain=field.domain, theta=.5)
ift.extra.check_jacobian_consistency(energy, field, tol=1e-6) ift.extra.check_jacobian_consistency(energy, field, tol=1e-6)
def test_hamiltonian_and_KL(field):
field = field.exp()
space = field.domain
lh = ift.GaussianEnergy(domain=space)
hamiltonian = ift.StandardHamiltonian(lh)
ift.extra.check_jacobian_consistency(hamiltonian, field)
S = ift.ScalingOperator(space, 1.)
samps = [S.draw_sample() for i in range(3)]
kl = ift.AveragedEnergy(hamiltonian, samps)
ift.extra.check_jacobian_consistency(kl, field)
def test_variablecovariancegaussian(field):
if isinstance(field.domain, ift.MultiDomain):
return
dc = {'a': field, 'b': field.exp()}
mf = ift.MultiField.from_dict(dc)
energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b')
ift.extra.check_jacobian_consistency(energy, mf, tol=1e-6)
energy(ift.Linearization.make_var(mf, want_metric=True)).metric.draw_sample()
def test_inverse_gamma(field): def test_inverse_gamma(field):
if isinstance(field.domain, ift.MultiDomain):
return
field = field.exp() field = field.exp()
space = field.domain space = field.domain
d = np.random.normal(10, size=space.shape)**2 d = np.random.normal(10, size=space.shape)**2
...@@ -84,6 +102,8 @@ def test_inverse_gamma(field): ...@@ -84,6 +102,8 @@ def test_inverse_gamma(field):
def testPoissonian(field): def testPoissonian(field):
if isinstance(field.domain, ift.MultiDomain):
return
field = field.exp() field = field.exp()
space = field.domain space = field.domain
d = np.random.poisson(120, size=space.shape) d = np.random.poisson(120, size=space.shape)
...@@ -92,19 +112,9 @@ def testPoissonian(field): ...@@ -92,19 +112,9 @@ def testPoissonian(field):
ift.extra.check_jacobian_consistency(energy, field, tol=1e-7) ift.extra.check_jacobian_consistency(energy, field, tol=1e-7)
def test_hamiltonian_and_KL(field):
field = field.exp()
space = field.domain
lh = ift.GaussianEnergy(domain=space)
hamiltonian = ift.StandardHamiltonian(lh)
ift.extra.check_jacobian_consistency(hamiltonian, field)
S = ift.ScalingOperator(space, 1.)
samps = [S.draw_sample() for i in range(3)]
kl = ift.AveragedEnergy(hamiltonian, samps)
ift.extra.check_jacobian_consistency(kl, field)
def test_bernoulli(field): def test_bernoulli(field):
if isinstance(field.domain, ift.MultiDomain):
return
field = field.sigmoid() field = field.sigmoid()
space = field.domain space = field.domain
d = np.random.binomial(1, 0.1, size=space.shape) d = np.random.binomial(1, 0.1, size=space.shape)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment