Skip to content
Snippets Groups Projects
Commit 90bcde2c authored by Reimar Leike's avatar Reimar Leike
Browse files

Add test for VariableCovarianceGaussianEnergy

parent 06d728fc
No related branches found
No related tags found
1 merge request!543Nifty627
Pipeline #76989 passed
...@@ -29,6 +29,8 @@ pmp = pytest.mark.parametrize ...@@ -29,6 +29,8 @@ pmp = pytest.mark.parametrize
field = list2fixture([ift.from_random(sp, 'normal') for sp in spaces] + field = list2fixture([ift.from_random(sp, 'normal') for sp in spaces] +
[ift.from_random(sp, 'normal', dtype=np.complex128) for sp in spaces]) [ift.from_random(sp, 'normal', dtype=np.complex128) for sp in spaces])
dtype = list2fixture([np.float64,
np.complex128])
Nsamp = 2000 Nsamp = 2000
np.random.seed(42) np.random.seed(42)
...@@ -110,7 +112,7 @@ def test_GaussianEnergy(field): ...@@ -110,7 +112,7 @@ def test_GaussianEnergy(field):
icov = ift.makeOp(icov) icov = ift.makeOp(icov)
get_noisy_data = lambda mean: mean + icov.draw_sample_with_dtype( get_noisy_data = lambda mean: mean + icov.draw_sample_with_dtype(
from_inverse=True, dtype=dtype) from_inverse=True, dtype=dtype)
E_init = lambda mean: ift.GaussianEnergy(mean=mean, inverse_covariance=icov) E_init = lambda data: ift.GaussianEnergy(mean=data, inverse_covariance=icov)
energy_tester(field, get_noisy_data, E_init) energy_tester(field, get_noisy_data, E_init)
...@@ -122,5 +124,20 @@ def test_PoissonEnergy(field): ...@@ -122,5 +124,20 @@ def test_PoissonEnergy(field):
get_noisy_data = lambda mean: ift.makeField(mean.domain, np.random.poisson(mean.val)) get_noisy_data = lambda mean: ift.makeField(mean.domain, np.random.poisson(mean.val))
# Make rate positive and high enough to avoid bad statistic # Make rate positive and high enough to avoid bad statistic
lam = 10*(field**2).clip(0.1, None) lam = 10*(field**2).clip(0.1, None)
E_init = lambda mean: ift.PoissonianEnergy(mean) E_init = lambda data: ift.PoissonianEnergy(data)
energy_tester(lam, get_noisy_data, E_init) energy_tester(lam, get_noisy_data, E_init)
def test_VariableCovarianceGaussianEnergy(dtype):
dom = ift.UnstructuredDomain(3)
res = ift.from_random(dom, 'normal', dtype=dtype)
ivar = ift.from_random(dom, 'normal')**2+4.
mf = ift.MultiField.from_dict({'res':res, 'ivar':ivar})
energy = ift.VariableCovarianceGaussianEnergy(dom, 'res', 'ivar', dtype)
def get_noisy_data(mean):
samp = ift.from_random(dom, 'normal', dtype)
samp = samp/mean['ivar'].sqrt()
return samp + mean['res']
def E_init(data):
adder = ift.Adder(ift.MultiField.from_dict({'res':data}), neg=True)
return energy.partial_insert(adder)
energy_tester(mf, get_noisy_data, E_init)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment