Commit bdedf261 authored by Philipp Arras's avatar Philipp Arras
Browse files

Fix tests

parent 31426c60
Pipeline #73463 passed with stages
in 19 minutes and 51 seconds
...@@ -30,6 +30,8 @@ pmp = pytest.mark.parametrize ...@@ -30,6 +30,8 @@ pmp = pytest.mark.parametrize
@pmp('mirror_samples', (True, False)) @pmp('mirror_samples', (True, False))
@pmp('mf', (True, False)) @pmp('mf', (True, False))
def test_kl(constants, point_estimates, mirror_samples, mf): def test_kl(constants, point_estimates, mirror_samples, mf):
if not mf and (len(point_estimates) != 0 or len(constants) != 0):
return
dom = ift.RGSpace((12,), (2.12)) dom = ift.RGSpace((12,), (2.12))
op = ift.HarmonicSmoothingOperator(dom, 3) op = ift.HarmonicSmoothingOperator(dom, 3)
if mf: if mf:
...@@ -55,10 +57,18 @@ def test_kl(constants, point_estimates, mirror_samples, mf): ...@@ -55,10 +57,18 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
napprox=0, napprox=0,
_local_samples=locsamp) _local_samples=locsamp)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(tuple(kl.samples)) == expected_nsamps)
# Test value # Test value
assert_allclose(kl.value, klpure.value) assert_allclose(kl.value, klpure.value)
# Test gradient # Test gradient
if not mf:
ift.extra.assert_allclose(kl.gradient, klpure.gradient, 0, 1e-14)
return
for kk in h.domain.keys(): for kk in h.domain.keys():
res0 = klpure.gradient[kk].val res0 = klpure.gradient[kk].val
if kk in constants: if kk in constants:
...@@ -66,10 +76,6 @@ def test_kl(constants, point_estimates, mirror_samples, mf): ...@@ -66,10 +76,6 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
res1 = kl.gradient[kk].val res1 = kl.gradient[kk].val
assert_allclose(res0, res1) assert_allclose(res0, res1)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(tuple(kl.samples)) == expected_nsamps)
# Test point_estimates (after drawing samples) # Test point_estimates (after drawing samples)
for kk in point_estimates: for kk in point_estimates:
for ss in kl.samples: for ss in kl.samples:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment