Commit 66c76e93 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'fix_mpi_kl' into 'NIFTy_6'

Fix mpi kl

See merge request !453
parents 959bc62a b1a9f4fa
Pipeline #73469 passed with stages
in 23 minutes and 8 seconds
......@@ -50,7 +50,7 @@ def _allreduce_sum_field(comm, fld):
if comm is None:
return fld
if isinstance(fld, Field):
return Field(fld.domain, _np_allreduce_sum(fld.val))
return Field(fld.domain, _np_allreduce_sum(comm, fld.val))
res = tuple(
Field(f.domain, _np_allreduce_sum(comm, f.val))
for f in fld.values())
......
......@@ -34,9 +34,9 @@ class QuadraticEnergy(Energy):
else:
Ax = self._A(self._position)
self._grad = Ax if b is None else Ax - b
self._value = 0.5*self._position.s_vdot(Ax)
self._value = 0.5*self._position.s_vdot(Ax).real
if b is not None:
self._value -= b.s_vdot(self._position)
self._value -= b.s_vdot(self._position).real
def at(self, position):
return QuadraticEnergy(position, self._A, self._b)
......
......@@ -15,8 +15,6 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
import nifty6 as ift
from numpy.testing import assert_, assert_allclose
import pytest
......@@ -28,10 +26,14 @@ pmp = pytest.mark.parametrize
@pmp('constants', ([], ['a'], ['b'], ['a', 'b']))
@pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b']))
@pmp('mirror_samples', (True, False))
def test_kl(constants, point_estimates, mirror_samples):
@pmp('mf', (True, False))
def test_kl(constants, point_estimates, mirror_samples, mf):
if not mf and (len(point_estimates) != 0 or len(constants) != 0):
return
dom = ift.RGSpace((12,), (2.12))
op0 = ift.HarmonicSmoothingOperator(dom, 3)
op = ift.ducktape(dom, None, 'a')*(op0.ducktape('b'))
op = ift.HarmonicSmoothingOperator(dom, 3)
if mf:
op = ift.ducktape(dom, None, 'a')*(op.ducktape('b'))
lh = ift.GaussianEnergy(domain=op.target) @ op
ic = ift.GradientNormController(iteration_limit=5)
h = ift.StandardHamiltonian(lh, ic_samp=ic)
......@@ -53,10 +55,18 @@ def test_kl(constants, point_estimates, mirror_samples):
napprox=0,
_local_samples=locsamp)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(tuple(kl.samples)) == expected_nsamps)
# Test value
assert_allclose(kl.value, klpure.value)
# Test gradient
if not mf:
ift.extra.assert_allclose(kl.gradient, klpure.gradient, 0, 1e-14)
return
for kk in h.domain.keys():
res0 = klpure.gradient[kk].val
if kk in constants:
......@@ -64,10 +74,6 @@ def test_kl(constants, point_estimates, mirror_samples):
res1 = kl.gradient[kk].val
assert_allclose(res0, res1)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(tuple(kl.samples)) == expected_nsamps)
# Test point_estimates (after drawing samples)
for kk in point_estimates:
for ss in kl.samples:
......
......@@ -38,10 +38,14 @@ pms = pytest.mark.skipif
@pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b']))
@pmp('mirror_samples', (False, True))
@pmp('mode', (0, 1))
def test_kl(constants, point_estimates, mirror_samples, mode):
@pmp('mf', (False, True))
def test_kl(constants, point_estimates, mirror_samples, mode, mf):
if not mf and (len(point_estimates) != 0 or len(constants) != 0):
return
dom = ift.RGSpace((12,), (2.12))
op0 = ift.HarmonicSmoothingOperator(dom, 3)
op = ift.ducktape(dom, None, 'a')*(op0.ducktape('b'))
op = ift.HarmonicSmoothingOperator(dom, 3)
if mf:
op = ift.ducktape(dom, None, 'a')*(op.ducktape('b'))
lh = ift.GaussianEnergy(domain=op.target) @ op
ic = ift.GradientNormController(iteration_limit=5)
h = ift.StandardHamiltonian(lh, ic_samp=ic)
......@@ -65,10 +69,18 @@ def test_kl(constants, point_estimates, mirror_samples, mode):
locsamp = samples[slc]
kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(tuple(kl0.samples)) == expected_nsamps)
assert_(len(tuple(kl1.samples)) == expected_nsamps)
# Test value
assert_allclose(kl0.value, kl1.value)
# Test gradient
if not mf:
ift.extra.assert_allclose(kl0.gradient, kl1.gradient, 0, 1e-14)
return
for kk in h.domain.keys():
res0 = kl0.gradient[kk].val
if kk in constants:
......@@ -76,11 +88,6 @@ def test_kl(constants, point_estimates, mirror_samples, mode):
res1 = kl1.gradient[kk].val
assert_allclose(res0, res1)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(tuple(kl0.samples)) == expected_nsamps)
assert_(len(tuple(kl1.samples)) == expected_nsamps)
# Test point_estimates (after drawing samples)
for kk in point_estimates:
for ss in kl0.samples:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment