Commit bdedf261 authored by Philipp Arras's avatar Philipp Arras
Browse files

Fix tests

parent 31426c60
Pipeline #73463 passed with stages
in 19 minutes and 51 seconds
......@@ -30,6 +30,8 @@ pmp = pytest.mark.parametrize
@pmp('mirror_samples', (True, False))
@pmp('mf', (True, False))
def test_kl(constants, point_estimates, mirror_samples, mf):
if not mf and (len(point_estimates) != 0 or len(constants) != 0):
return
dom = ift.RGSpace((12,), (2.12))
op = ift.HarmonicSmoothingOperator(dom, 3)
if mf:
......@@ -55,10 +57,18 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
napprox=0,
_local_samples=locsamp)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(tuple(kl.samples)) == expected_nsamps)
# Test value
assert_allclose(kl.value, klpure.value)
# Test gradient
if not mf:
ift.extra.assert_allclose(kl.gradient, klpure.gradient, 0, 1e-14)
return
for kk in h.domain.keys():
res0 = klpure.gradient[kk].val
if kk in constants:
......@@ -66,10 +76,6 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
res1 = kl.gradient[kk].val
assert_allclose(res0, res1)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(tuple(kl.samples)) == expected_nsamps)
# Test point_estimates (after drawing samples)
for kk in point_estimates:
for ss in kl.samples:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment