Commit 90bcde2c authored by Reimar Leike's avatar Reimar Leike
Browse files

Add test for VariableCovarianceGaussianEnergy

parent 06d728fc
Pipeline #76989 passed with stages
in 12 minutes and 37 seconds
...@@ -29,6 +29,8 @@ pmp = pytest.mark.parametrize ...@@ -29,6 +29,8 @@ pmp = pytest.mark.parametrize
field = list2fixture([ift.from_random(sp, 'normal') for sp in spaces] + field = list2fixture([ift.from_random(sp, 'normal') for sp in spaces] +
[ift.from_random(sp, 'normal', dtype=np.complex128) for sp in spaces]) [ift.from_random(sp, 'normal', dtype=np.complex128) for sp in spaces])
dtype = list2fixture([np.float64,
np.complex128])
Nsamp = 2000 Nsamp = 2000
np.random.seed(42) np.random.seed(42)
...@@ -110,7 +112,7 @@ def test_GaussianEnergy(field): ...@@ -110,7 +112,7 @@ def test_GaussianEnergy(field):
icov = ift.makeOp(icov) icov = ift.makeOp(icov)
get_noisy_data = lambda mean: mean + icov.draw_sample_with_dtype( get_noisy_data = lambda mean: mean + icov.draw_sample_with_dtype(
from_inverse=True, dtype=dtype) from_inverse=True, dtype=dtype)
E_init = lambda mean: ift.GaussianEnergy(mean=mean, inverse_covariance=icov) E_init = lambda data: ift.GaussianEnergy(mean=data, inverse_covariance=icov)
energy_tester(field, get_noisy_data, E_init) energy_tester(field, get_noisy_data, E_init)
...@@ -122,5 +124,20 @@ def test_PoissonEnergy(field): ...@@ -122,5 +124,20 @@ def test_PoissonEnergy(field):
get_noisy_data = lambda mean: ift.makeField(mean.domain, np.random.poisson(mean.val)) get_noisy_data = lambda mean: ift.makeField(mean.domain, np.random.poisson(mean.val))
# Make rate positive and high enough to avoid bad statistic # Make rate positive and high enough to avoid bad statistic
lam = 10*(field**2).clip(0.1, None) lam = 10*(field**2).clip(0.1, None)
E_init = lambda mean: ift.PoissonianEnergy(mean) E_init = lambda data: ift.PoissonianEnergy(data)
energy_tester(lam, get_noisy_data, E_init) energy_tester(lam, get_noisy_data, E_init)
def test_VariableCovarianceGaussianEnergy(dtype):
dom = ift.UnstructuredDomain(3)
res = ift.from_random(dom, 'normal', dtype=dtype)
ivar = ift.from_random(dom, 'normal')**2+4.
mf = ift.MultiField.from_dict({'res':res, 'ivar':ivar})
energy = ift.VariableCovarianceGaussianEnergy(dom, 'res', 'ivar', dtype)
def get_noisy_data(mean):
samp = ift.from_random(dom, 'normal', dtype)
samp = samp/mean['ivar'].sqrt()
return samp + mean['res']
def E_init(data):
adder = ift.Adder(ift.MultiField.from_dict({'res':data}), neg=True)
return energy.partial_insert(adder)
energy_tester(mf, get_noisy_data, E_init)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment