Commit d99be923 authored by Philipp Arras's avatar Philipp Arras
Browse files

Formatting

parent 11210943
Pipeline #76980 passed with stages
in 12 minutes and 13 seconds
......@@ -26,17 +26,19 @@ spaces = [ift.GLSpace(5),
(ift.RGSpace(3, distances=.789), ift.UnstructuredDomain(2))]
pmp = pytest.mark.parametrize
field = list2fixture([ift.from_random(sp, 'normal') for sp in spaces] +
[ift.from_random(sp, 'normal', dtype=np.complex128) for sp in spaces])
[ift.from_random(sp, 'normal', dtype=np.complex128) for sp in spaces])
Nsamp = 2000
np.random.seed(42)
def _to_array(d):
if isinstance(d, np.ndarray):
return d
assert isinstance(d, dict)
return np.concatenate(list(d.values()))
def _complex2real(sp):
tup = tuple([d for d in sp])
rsp = ift.DomainTuple.make((ift.UnstructuredDomain(2),) + tup)
......@@ -45,6 +47,7 @@ def _complex2real(sp):
x = ift.ScalingOperator(sp, 1)
return rl(x.real)+im(x.imag)
def test_complex2real():
sp = ift.UnstructuredDomain(3)
op = _complex2real(ift.makeDomain(sp))
......@@ -53,14 +56,16 @@ def test_complex2real():
assert op(f).dtype == np.float64
f = ift.from_random(op.target, 'normal')
assert np.all((f == op(op.adjoint_times(f))).val)
def energy_tester_complex(pos, get_noisy_data, energy_initializer):
op = _complex2real(pos.domain)
npos = op(pos)
nget_noisy_data = lambda mean : get_noisy_data(op.adjoint_times(mean))
nenergy_initializer = lambda mean : energy_initializer(mean) @ op.adjoint
nget_noisy_data = lambda mean: get_noisy_data(op.adjoint_times(mean))
nenergy_initializer = lambda mean: energy_initializer(mean) @ op.adjoint
energy_tester(npos, nget_noisy_data, nenergy_initializer)
def energy_tester(pos, get_noisy_data, energy_initializer):
if np.issubdtype(pos.dtype, np.complexfloating):
energy_tester_complex(pos, get_noisy_data, energy_initializer)
......@@ -83,25 +88,24 @@ def energy_tester(pos, get_noisy_data, energy_initializer):
res2 = _to_array(energy(lin).metric(test_vec).val)
np.testing.assert_allclose(res/std, res2/std, atol=6)
def test_GaussianEnergy(field):
dtype = field.dtype
icov = ift.from_random(field.domain, 'normal')**2
icov = ift.makeOp(icov)
get_noisy_data = lambda mean : mean + icov.draw_sample_with_dtype(
from_inverse=True, dtype=dtype)
E_init = lambda mean : ift.GaussianEnergy(mean=mean,
inverse_covariance=icov)
get_noisy_data = lambda mean: mean + icov.draw_sample_with_dtype(
from_inverse=True, dtype=dtype)
E_init = lambda mean: ift.GaussianEnergy(mean=mean, inverse_covariance=icov)
energy_tester(field, get_noisy_data, E_init)
def test_PoissonEnergy(field):
if not isinstance(field, ift.Field):
return
if np.iscomplexobj(field.val):
return
def get_noisy_data(mean):
return ift.makeField(mean.domain, np.random.poisson(mean.val))
lam = 10*(field**2).clip(0.1,None) # make rate positive and high enough to avoid bad statistic
E_init = lambda mean : ift.PoissonianEnergy(mean)
get_noisy_data = lambda mean: ift.makeField(mean.domain, np.random.poisson(mean.val))
# Make rate positive and high enough to avoid bad statistic
lam = 10*(field**2).clip(0.1, None)
E_init = lambda mean: ift.PoissonianEnergy(mean)
energy_tester(lam, get_noisy_data, E_init)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment