From cbc3b1659cb43c6782db36abba2d9a35b1db0139 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Fri, 22 Feb 2019 13:17:58 +0100
Subject: [PATCH] re-introduce mpi_experiments

---
 nifty5/__init__.py                        |   2 +-
 nifty5/dobj.py                            |  19 ++--
 nifty5/minimization/metric_gaussian_kl.py | 127 ++++++++++++++++++++++
 3 files changed, 140 insertions(+), 8 deletions(-)

diff --git a/nifty5/__init__.py b/nifty5/__init__.py
index 2cb37c085..66a7c17ab 100644
--- a/nifty5/__init__.py
+++ b/nifty5/__init__.py
@@ -69,7 +69,7 @@ from .minimization.scipy_minimizer import L_BFGS_B
 from .minimization.energy import Energy
 from .minimization.quadratic_energy import QuadraticEnergy
 from .minimization.energy_adapter import EnergyAdapter
-from .minimization.metric_gaussian_kl import MetricGaussianKL
+from .minimization.metric_gaussian_kl import MetricGaussianKL, MetricGaussianKL_MPI
 
 from .sugar import *
 from .plot import Plot
diff --git a/nifty5/dobj.py b/nifty5/dobj.py
index 47dc22979..9c56da4bc 100644
--- a/nifty5/dobj.py
+++ b/nifty5/dobj.py
@@ -15,11 +15,16 @@
 #
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
 
-try:
-    from mpi4py import MPI
-    if MPI.COMM_WORLD.Get_size() == 1:
-        from .data_objects.numpy_do import *
-    else:
-        from .data_objects.distributed_do import *
-except ImportError:
+_always_use_numpy_do = True
+
+if _always_use_numpy_do:
     from .data_objects.numpy_do import *
+else:
+    try:
+        from mpi4py import MPI
+        if MPI.COMM_WORLD.Get_size() == 1:
+            from .data_objects.numpy_do import *
+        else:
+            from .data_objects.distributed_do import *
+    except ImportError:
+        from .data_objects.numpy_do import *
diff --git a/nifty5/minimization/metric_gaussian_kl.py b/nifty5/minimization/metric_gaussian_kl.py
index c42844a36..277aec00a 100644
--- a/nifty5/minimization/metric_gaussian_kl.py
+++ b/nifty5/minimization/metric_gaussian_kl.py
@@ -148,3 +148,130 @@ class MetricGaussianKL(Energy):
     def __repr__(self):
         return 'KL ({} samples):\n'.format(len(
             self._samples)) + utilities.indent(self._hamiltonian.__repr__())
+
+
+from mpi4py import MPI
+import numpy as np
+from ..field import Field
+from ..multi_field import MultiField
+_comm = MPI.COMM_WORLD
+ntask = _comm.Get_size()
+rank = _comm.Get_rank()
+master = (rank == 0)
+
+
+def _shareRange(nwork, nshares, myshare):
+    nbase = nwork//nshares
+    additional = nwork % nshares
+    lo = myshare*nbase + min(myshare, additional)
+    hi = lo + nbase + int(myshare < additional)
+    return lo, hi
+
+
+def np_allreduce_sum(arr):
+    res = np.empty_like(arr)
+    _comm.Allreduce(arr, res, MPI.SUM)
+    return res
+
+
+def allreduce_sum_field(fld):
+    if isinstance(fld, Field):
+        return Field.from_local_data(fld.domain,
+                                     np_allreduce_sum(fld.local_data))
+    res = tuple(
+        Field.from_local_data(f.domain, np_allreduce_sum(f.local_data))
+        for f in fld.values())
+    return MultiField(fld.domain, res)
+
+
+class MetricGaussianKL_MPI(Energy):
+    def __init__(self, mean, hamiltonian, n_samples, constants=[],
+                 point_estimates=[], mirror_samples=False,
+                 _samples=None):
+        super(MetricGaussianKL_MPI, self).__init__(mean)
+
+        if not isinstance(hamiltonian, StandardHamiltonian):
+            raise TypeError
+        if hamiltonian.domain is not mean.domain:
+            raise ValueError
+        if not isinstance(n_samples, int):
+            raise TypeError
+        self._constants = list(constants)
+        self._point_estimates = list(point_estimates)
+        if not isinstance(mirror_samples, bool):
+            raise TypeError
+
+        self._hamiltonian = hamiltonian
+
+        if _samples is None:
+            lo, hi = _shareRange(n_samples, ntask, rank)
+            met = hamiltonian(Linearization.make_partial_var(
+                mean, point_estimates, True)).metric
+            _samples = []
+            for i in range(lo, hi):
+                np.random.seed(i)
+                _samples.append(met.draw_sample(from_inverse=True))
+            if mirror_samples:
+                _samples += [-s for s in _samples]
+                n_samples *= 2
+            _samples = tuple(_samples)
+        self._samples = _samples
+        self._n_samples = n_samples
+        self._lin = Linearization.make_partial_var(mean, constants)
+        v, g = None, None
+        if len(self._samples) == 0:  # hack if there are too many MPI tasks
+            tmp = self._hamiltonian(self._lin)
+            v = 0. * tmp.val.local_data[()]
+            g = 0. * tmp.gradient
+        else:
+            for s in self._samples:
+                tmp = self._hamiltonian(self._lin+s)
+                if v is None:
+                    v = tmp.val.local_data[()]
+                    g = tmp.gradient
+                else:
+                    v += tmp.val.local_data[()]
+                    g = g + tmp.gradient
+        self._val = np_allreduce_sum(v) / self._n_samples
+        self._grad = allreduce_sum_field(g) / self._n_samples
+        self._metric = None
+
+    def at(self, position):
+        return MetricGaussianKL_MPI(
+            position, self._hamiltonian, self._n_samples, self._constants,
+            self._point_estimates, _samples=self._samples)
+
+    @property
+    def value(self):
+        return self._val
+
+    @property
+    def gradient(self):
+        return self._grad
+
+    def _get_metric(self):
+        lin = self._lin.with_want_metric()
+        if self._metric is None:
+            if len(self._samples) == 0:  # hack if there are too many MPI tasks
+                self._metric = self._hamiltonian(lin).metric.scale(0.)
+            else:
+                mymap = map(lambda v: self._hamiltonian(lin+v).metric,
+                            self._samples)
+                self._metric = utilities.my_sum(mymap)
+                self._metric = self._metric.scale(1./self._n_samples)
+
+    def apply_metric(self, x):
+        self._get_metric()
+        return allreduce_sum_field(self._metric(x))
+
+    @property
+    def metric(self):
+        if ntask > 1:
+            raise ValueError("not supported when MPI is active")
+        return self._metric
+
+    @property
+    def samples(self):
+        res = _comm.allgather(self._samples)
+        res = [item for sublist in res for item in sublist]
+        return res
-- 
GitLab