Commit 31426c60 authored by Philipp Arras's avatar Philipp Arras
Browse files

Fixup untested code in KL

parent 91f9c9ff
Pipeline #73448 failed with stages
in 14 minutes and 33 seconds
......@@ -50,7 +50,7 @@ def _allreduce_sum_field(comm, fld):
if comm is None:
return fld
if isinstance(fld, Field):
return Field(fld.domain, _np_allreduce_sum(fld.val))
return Field(fld.domain, _np_allreduce_sum(comm, fld.val))
res = tuple(
Field(f.domain, _np_allreduce_sum(comm, f.val))
for f in fld.values())
......
......@@ -40,6 +40,8 @@ pms = pytest.mark.skipif
@pmp('mode', (0, 1))
@pmp('mf', (False, True))
def test_kl(constants, point_estimates, mirror_samples, mode, mf):
if not mf and (len(point_estimates) != 0 or len(constants) != 0):
return
dom = ift.RGSpace((12,), (2.12))
op = ift.HarmonicSmoothingOperator(dom, 3)
if mf:
......@@ -67,10 +69,18 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf):
locsamp = samples[slc]
kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(tuple(kl0.samples)) == expected_nsamps)
assert_(len(tuple(kl1.samples)) == expected_nsamps)
# Test value
assert_allclose(kl0.value, kl1.value)
# Test gradient
if not mf:
ift.extra.assert_allclose(kl0.gradient, kl1.gradient, 0, 1e-14)
return
for kk in h.domain.keys():
res0 = kl0.gradient[kk].val
if kk in constants:
......@@ -78,11 +88,6 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf):
res1 = kl1.gradient[kk].val
assert_allclose(res0, res1)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(tuple(kl0.samples)) == expected_nsamps)
assert_(len(tuple(kl1.samples)) == expected_nsamps)
# Test point_estimates (after drawing samples)
for kk in point_estimates:
for ss in kl0.samples:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment