Skip to content
Snippets Groups Projects
Commit 31426c60 authored by Philipp Arras's avatar Philipp Arras
Browse files

Fixup untested code in KL

parent 91f9c9ff
Branches
Tags
1 merge request!453Fix mpi kl
Pipeline #73448 failed
...@@ -50,7 +50,7 @@ def _allreduce_sum_field(comm, fld): ...@@ -50,7 +50,7 @@ def _allreduce_sum_field(comm, fld):
if comm is None: if comm is None:
return fld return fld
if isinstance(fld, Field): if isinstance(fld, Field):
return Field(fld.domain, _np_allreduce_sum(fld.val)) return Field(fld.domain, _np_allreduce_sum(comm, fld.val))
res = tuple( res = tuple(
Field(f.domain, _np_allreduce_sum(comm, f.val)) Field(f.domain, _np_allreduce_sum(comm, f.val))
for f in fld.values()) for f in fld.values())
......
...@@ -40,6 +40,8 @@ pms = pytest.mark.skipif ...@@ -40,6 +40,8 @@ pms = pytest.mark.skipif
@pmp('mode', (0, 1)) @pmp('mode', (0, 1))
@pmp('mf', (False, True)) @pmp('mf', (False, True))
def test_kl(constants, point_estimates, mirror_samples, mode, mf): def test_kl(constants, point_estimates, mirror_samples, mode, mf):
if not mf and (len(point_estimates) != 0 or len(constants) != 0):
return
dom = ift.RGSpace((12,), (2.12)) dom = ift.RGSpace((12,), (2.12))
op = ift.HarmonicSmoothingOperator(dom, 3) op = ift.HarmonicSmoothingOperator(dom, 3)
if mf: if mf:
...@@ -67,10 +69,18 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf): ...@@ -67,10 +69,18 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf):
locsamp = samples[slc] locsamp = samples[slc]
kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp) kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(tuple(kl0.samples)) == expected_nsamps)
assert_(len(tuple(kl1.samples)) == expected_nsamps)
# Test value # Test value
assert_allclose(kl0.value, kl1.value) assert_allclose(kl0.value, kl1.value)
# Test gradient # Test gradient
if not mf:
ift.extra.assert_allclose(kl0.gradient, kl1.gradient, 0, 1e-14)
return
for kk in h.domain.keys(): for kk in h.domain.keys():
res0 = kl0.gradient[kk].val res0 = kl0.gradient[kk].val
if kk in constants: if kk in constants:
...@@ -78,11 +88,6 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf): ...@@ -78,11 +88,6 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf):
res1 = kl1.gradient[kk].val res1 = kl1.gradient[kk].val
assert_allclose(res0, res1) assert_allclose(res0, res1)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(tuple(kl0.samples)) == expected_nsamps)
assert_(len(tuple(kl1.samples)) == expected_nsamps)
# Test point_estimates (after drawing samples) # Test point_estimates (after drawing samples)
for kk in point_estimates: for kk in point_estimates:
for ss in kl0.samples: for ss in kl0.samples:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment