From 90bcde2c580736f992308efdbd8105f83c43725f Mon Sep 17 00:00:00 2001 From: Reimar Leike <reimar@mpa-garhcing.mpg.de> Date: Fri, 19 Jun 2020 17:32:31 +0200 Subject: [PATCH] Add test for VariableCovarianceGaussianEnergy --- test/test_operators/test_fisher_metric.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/test/test_operators/test_fisher_metric.py b/test/test_operators/test_fisher_metric.py index 6700bd022..1308cee50 100644 --- a/test/test_operators/test_fisher_metric.py +++ b/test/test_operators/test_fisher_metric.py @@ -29,6 +29,8 @@ pmp = pytest.mark.parametrize field = list2fixture([ift.from_random(sp, 'normal') for sp in spaces] + [ift.from_random(sp, 'normal', dtype=np.complex128) for sp in spaces]) +dtype = list2fixture([np.float64, + np.complex128]) Nsamp = 2000 np.random.seed(42) @@ -110,7 +112,7 @@ def test_GaussianEnergy(field): icov = ift.makeOp(icov) get_noisy_data = lambda mean: mean + icov.draw_sample_with_dtype( from_inverse=True, dtype=dtype) - E_init = lambda mean: ift.GaussianEnergy(mean=mean, inverse_covariance=icov) + E_init = lambda data: ift.GaussianEnergy(mean=data, inverse_covariance=icov) energy_tester(field, get_noisy_data, E_init) @@ -122,5 +124,20 @@ def test_PoissonEnergy(field): get_noisy_data = lambda mean: ift.makeField(mean.domain, np.random.poisson(mean.val)) # Make rate positive and high enough to avoid bad statistic lam = 10*(field**2).clip(0.1, None) - E_init = lambda mean: ift.PoissonianEnergy(mean) + E_init = lambda data: ift.PoissonianEnergy(data) energy_tester(lam, get_noisy_data, E_init) + +def test_VariableCovarianceGaussianEnergy(dtype): + dom = ift.UnstructuredDomain(3) + res = ift.from_random(dom, 'normal', dtype=dtype) + ivar = ift.from_random(dom, 'normal')**2+4. + mf = ift.MultiField.from_dict({'res':res, 'ivar':ivar}) + energy = ift.VariableCovarianceGaussianEnergy(dom, 'res', 'ivar', dtype) + def get_noisy_data(mean): + samp = ift.from_random(dom, 'normal', dtype) + samp = samp/mean['ivar'].sqrt() + return samp + mean['res'] + def E_init(data): + adder = ift.Adder(ift.MultiField.from_dict({'res':data}), neg=True) + return energy.partial_insert(adder) + energy_tester(mf, get_noisy_data, E_init) -- GitLab