diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index b640db93abb04d1e07d59fc0eae7d9c75eba632a..5008449be4eb1c53ded405dc3a0c4558d79a9894 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -43,6 +43,13 @@ test_serial: - > grep TOTAL coverage.txt | awk '{ print "TOTAL: "$4; }' +test_mpi: + stage: test + variables: + OMPI_MCA_btl_vader_single_copy_mechanism: none + script: + - mpiexec -n 2 --bind-to none pytest-3 -q test/test_mpi + pages: stage: release script: diff --git a/test/test_mpi/__init__.py b/test/test_mpi/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/test_mpi/test_kl.py b/test/test_mpi/test_kl.py new file mode 100644 index 0000000000000000000000000000000000000000..3219a369d015754385e0821f1212162ebe57dd87 --- /dev/null +++ b/test/test_mpi/test_kl.py @@ -0,0 +1,100 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2019 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. + +import pytest +from mpi4py import MPI +from numpy.testing import assert_, assert_allclose + +import nifty6 as ift + +from ..common import setup_function, teardown_function + +comm = MPI.COMM_WORLD +ntask = comm.Get_size() +rank = comm.Get_rank() +master = (rank == 0) +mpi = ntask > 1 + +pmp = pytest.mark.parametrize +pms = pytest.mark.skipif + + +@pms(ntask != 2, reason="requires exactly two mpi tasks") +@pmp('constants', ([], ['a'], ['b'], ['a', 'b'])) +@pmp('point_estimates', ([], ['a'], ['b'], ['a', 'b'])) +@pmp('mirror_samples', (False, True)) +@pmp('mode', (0, 1)) +def test_kl(constants, point_estimates, mirror_samples, mode): + dom = ift.RGSpace((12,), (2.12)) + op0 = ift.HarmonicSmoothingOperator(dom, 3) + op = ift.ducktape(dom, None, 'a')*(op0.ducktape('b')) + lh = ift.GaussianEnergy(domain=op.target) @ op + ic = ift.GradientNormController(iteration_limit=5) + h = ift.StandardHamiltonian(lh, ic_samp=ic) + mean0 = ift.from_random('normal', h.domain) + nsamps = 2 + args = {'constants': constants, + 'point_estimates': point_estimates, + 'mirror_samples': mirror_samples, + 'n_samples': 2, + 'mean': mean0, + 'hamiltonian': h} + if mode == 0: + kl0 = ift.MetricGaussianKL(**args, comm=comm) + locsamp = kl0._local_samples + kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp) + elif mode == 1: + kl0 = ift.MetricGaussianKL(**args) + samples = kl0._local_samples + ii = len(samples)//2 + slc = slice(None, ii) if rank == 0 else slice(ii, None) + locsamp = samples[slc] + kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp) + + # Test value + assert_allclose(kl0.value, kl1.value) + + # Test gradient + for kk in h.domain.keys(): + res0 = kl0.gradient[kk].val + if kk in constants: + res0 = 0*res0 + res1 = kl1.gradient[kk].val + assert_allclose(res0, res1) + + # Test number of samples + expected_nsamps = 2*nsamps if mirror_samples else nsamps + assert_(len(tuple(kl0.samples)) == expected_nsamps) + assert_(len(tuple(kl1.samples)) == expected_nsamps) + + # Test point_estimates (after drawing samples) + for kk in point_estimates: + for ss in kl0.samples: + ss = ss[kk].val + assert_allclose(ss, 0*ss) + for ss in kl1.samples: + ss = ss[kk].val + assert_allclose(ss, 0*ss) + + # Test constants (after some minimization) + cg = ift.GradientNormController(iteration_limit=5) + minimizer = ift.NewtonCG(cg) + for e in [kl0, kl1]: + e, _ = minimizer(e) + diff = (mean0 - e.position).to_dict() + for kk in constants: + assert_allclose(diff[kk].val, 0*diff[kk].val)