Commit e59b5d06 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add tests for debugging

parent fe8ec889
......@@ -153,7 +153,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
Data type of the samples. Usually either 'np.float*' or 'np.complex*'
"""
def __init__(self, domain, residual_key, inverse_covariance_key, sampling_dtype):
def __init__(self, domain, residual_key, inverse_covariance_key, sampling_dtype, _debugging_factor=1.):
self._kr = str(residual_key)
self._ki = str(inverse_covariance_key)
dom = DomainTuple.make(domain)
......@@ -161,6 +161,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
self._dt = {self._kr: sampling_dtype, self._ki: np.float64}
_check_sampling_dtype(self._domain, self._dt)
self._cplx = _iscomplex(sampling_dtype)
self._factor = float(_debugging_factor)
def apply(self, x):
self._check_input(x)
......@@ -173,7 +174,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
return res
met = i.val if self._cplx else 0.5*i.val
# FIXME DO NOT MERGE THAT
met = MultiField.from_dict({self._kr: i.val, self._ki: 2*met**(-2)})
met = MultiField.from_dict({self._kr: i.val, self._ki: self._factor*met**(-2)})
return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt))
def _simplify_for_constant_input_nontrivial(self, c_inp):
......
......@@ -128,12 +128,16 @@ def test_PoissonEnergy(field):
E_init = lambda data: ift.PoissonianEnergy(data)
energy_tester(lam, get_noisy_data, E_init)
def test_VariableCovarianceGaussianEnergy(dtype):
@pytest.mark.parametrize('factor', [0.01, 0.5, 1, 2, 100])
def test_VariableCovarianceGaussianEnergy(dtype, factor):
if np.issubdtype(dtype, np.complexfloating):
pytest.skip()
dom = ift.UnstructuredDomain(3)
res = ift.from_random(dom, 'normal', dtype=dtype)
ivar = ift.from_random(dom, 'normal')**2+4.
mf = ift.MultiField.from_dict({'res':res, 'ivar':ivar})
energy = ift.VariableCovarianceGaussianEnergy(dom, 'res', 'ivar', dtype)
energy = ift.VariableCovarianceGaussianEnergy(dom, 'res', 'ivar', dtype, _debugging_factor=factor)
def get_noisy_data(mean):
samp = ift.from_random(dom, 'normal', dtype)
samp = samp/mean['ivar'].sqrt()
......@@ -141,4 +145,8 @@ def test_VariableCovarianceGaussianEnergy(dtype):
def E_init(data):
adder = ift.Adder(ift.MultiField.from_dict({'res':data}), neg=True)
return energy.partial_insert(adder)
energy_tester(mf, get_noisy_data, E_init)
if factor != 1.:
with pytest.raises(AssertionError):
energy_tester(mf, get_noisy_data, E_init)
else:
energy_tester(mf, get_noisy_data, E_init)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment