Commit 66c76e93 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'fix_mpi_kl' into 'NIFTy_6'

Fix mpi kl

See merge request !453
parents 959bc62a b1a9f4fa
Pipeline #73469 passed with stages
in 23 minutes and 8 seconds
...@@ -50,7 +50,7 @@ def _allreduce_sum_field(comm, fld): ...@@ -50,7 +50,7 @@ def _allreduce_sum_field(comm, fld):
if comm is None: if comm is None:
return fld return fld
if isinstance(fld, Field): if isinstance(fld, Field):
return Field(fld.domain, _np_allreduce_sum(fld.val)) return Field(fld.domain, _np_allreduce_sum(comm, fld.val))
res = tuple( res = tuple(
Field(f.domain, _np_allreduce_sum(comm, f.val)) Field(f.domain, _np_allreduce_sum(comm, f.val))
for f in fld.values()) for f in fld.values())
......
...@@ -34,9 +34,9 @@ class QuadraticEnergy(Energy): ...@@ -34,9 +34,9 @@ class QuadraticEnergy(Energy):
else: else:
Ax = self._A(self._position) Ax = self._A(self._position)
self._grad = Ax if b is None else Ax - b self._grad = Ax if b is None else Ax - b
self._value = 0.5*self._position.s_vdot(Ax) self._value = 0.5*self._position.s_vdot(Ax).real
if b is not None: if b is not None:
self._value -= b.s_vdot(self._position) self._value -= b.s_vdot(self._position).real
def at(self, position): def at(self, position):
return QuadraticEnergy(position, self._A, self._b) return QuadraticEnergy(position, self._A, self._b)
......
...@@ -15,8 +15,6 @@ ...@@ -15,8 +15,6 @@
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
import nifty6 as ift import nifty6 as ift
from numpy.testing import assert_, assert_allclose from numpy.testing import assert_, assert_allclose
import pytest import pytest
...@@ -28,10 +26,14 @@ pmp = pytest.mark.parametrize ...@@ -28,10 +26,14 @@ pmp = pytest.mark.parametrize
@pmp('constants', ([], ['a'], ['b'], ['a', 'b'])) @pmp('constants', ([], ['a'], ['b'], ['a', 'b']))
@pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b'])) @pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b']))
@pmp('mirror_samples', (True, False)) @pmp('mirror_samples', (True, False))
def test_kl(constants, point_estimates, mirror_samples): @pmp('mf', (True, False))
def test_kl(constants, point_estimates, mirror_samples, mf):
if not mf and (len(point_estimates) != 0 or len(constants) != 0):
return
dom = ift.RGSpace((12,), (2.12)) dom = ift.RGSpace((12,), (2.12))
op0 = ift.HarmonicSmoothingOperator(dom, 3) op = ift.HarmonicSmoothingOperator(dom, 3)
op = ift.ducktape(dom, None, 'a')*(op0.ducktape('b')) if mf:
op = ift.ducktape(dom, None, 'a')*(op.ducktape('b'))
lh = ift.GaussianEnergy(domain=op.target) @ op lh = ift.GaussianEnergy(domain=op.target) @ op
ic = ift.GradientNormController(iteration_limit=5) ic = ift.GradientNormController(iteration_limit=5)
h = ift.StandardHamiltonian(lh, ic_samp=ic) h = ift.StandardHamiltonian(lh, ic_samp=ic)
...@@ -53,10 +55,18 @@ def test_kl(constants, point_estimates, mirror_samples): ...@@ -53,10 +55,18 @@ def test_kl(constants, point_estimates, mirror_samples):
napprox=0, napprox=0,
_local_samples=locsamp) _local_samples=locsamp)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(tuple(kl.samples)) == expected_nsamps)
# Test value # Test value
assert_allclose(kl.value, klpure.value) assert_allclose(kl.value, klpure.value)
# Test gradient # Test gradient
if not mf:
ift.extra.assert_allclose(kl.gradient, klpure.gradient, 0, 1e-14)
return
for kk in h.domain.keys(): for kk in h.domain.keys():
res0 = klpure.gradient[kk].val res0 = klpure.gradient[kk].val
if kk in constants: if kk in constants:
...@@ -64,10 +74,6 @@ def test_kl(constants, point_estimates, mirror_samples): ...@@ -64,10 +74,6 @@ def test_kl(constants, point_estimates, mirror_samples):
res1 = kl.gradient[kk].val res1 = kl.gradient[kk].val
assert_allclose(res0, res1) assert_allclose(res0, res1)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(tuple(kl.samples)) == expected_nsamps)
# Test point_estimates (after drawing samples) # Test point_estimates (after drawing samples)
for kk in point_estimates: for kk in point_estimates:
for ss in kl.samples: for ss in kl.samples:
......
...@@ -38,10 +38,14 @@ pms = pytest.mark.skipif ...@@ -38,10 +38,14 @@ pms = pytest.mark.skipif
@pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b'])) @pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b']))
@pmp('mirror_samples', (False, True)) @pmp('mirror_samples', (False, True))
@pmp('mode', (0, 1)) @pmp('mode', (0, 1))
def test_kl(constants, point_estimates, mirror_samples, mode): @pmp('mf', (False, True))
def test_kl(constants, point_estimates, mirror_samples, mode, mf):
if not mf and (len(point_estimates) != 0 or len(constants) != 0):
return
dom = ift.RGSpace((12,), (2.12)) dom = ift.RGSpace((12,), (2.12))
op0 = ift.HarmonicSmoothingOperator(dom, 3) op = ift.HarmonicSmoothingOperator(dom, 3)
op = ift.ducktape(dom, None, 'a')*(op0.ducktape('b')) if mf:
op = ift.ducktape(dom, None, 'a')*(op.ducktape('b'))
lh = ift.GaussianEnergy(domain=op.target) @ op lh = ift.GaussianEnergy(domain=op.target) @ op
ic = ift.GradientNormController(iteration_limit=5) ic = ift.GradientNormController(iteration_limit=5)
h = ift.StandardHamiltonian(lh, ic_samp=ic) h = ift.StandardHamiltonian(lh, ic_samp=ic)
...@@ -65,10 +69,18 @@ def test_kl(constants, point_estimates, mirror_samples, mode): ...@@ -65,10 +69,18 @@ def test_kl(constants, point_estimates, mirror_samples, mode):
locsamp = samples[slc] locsamp = samples[slc]
kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp) kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(tuple(kl0.samples)) == expected_nsamps)
assert_(len(tuple(kl1.samples)) == expected_nsamps)
# Test value # Test value
assert_allclose(kl0.value, kl1.value) assert_allclose(kl0.value, kl1.value)
# Test gradient # Test gradient
if not mf:
ift.extra.assert_allclose(kl0.gradient, kl1.gradient, 0, 1e-14)
return
for kk in h.domain.keys(): for kk in h.domain.keys():
res0 = kl0.gradient[kk].val res0 = kl0.gradient[kk].val
if kk in constants: if kk in constants:
...@@ -76,11 +88,6 @@ def test_kl(constants, point_estimates, mirror_samples, mode): ...@@ -76,11 +88,6 @@ def test_kl(constants, point_estimates, mirror_samples, mode):
res1 = kl1.gradient[kk].val res1 = kl1.gradient[kk].val
assert_allclose(res0, res1) assert_allclose(res0, res1)
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(tuple(kl0.samples)) == expected_nsamps)
assert_(len(tuple(kl1.samples)) == expected_nsamps)
# Test point_estimates (after drawing samples) # Test point_estimates (after drawing samples)
for kk in point_estimates: for kk in point_estimates:
for ss in kl0.samples: for ss in kl0.samples:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment