From cbc3b1659cb43c6782db36abba2d9a35b1db0139 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Fri, 22 Feb 2019 13:17:58 +0100
Subject: [PATCH] re-introduce mpi_experiments

---
 nifty5/__init__.py                        |   2 +-
 nifty5/dobj.py                            |  19 ++--
 nifty5/minimization/metric_gaussian_kl.py | 127 ++++++++++++++++++++++
 3 files changed, 140 insertions(+), 8 deletions(-)

diff --git a/nifty5/__init__.py b/nifty5/__init__.py
index 2cb37c085..66a7c17ab 100644
--- a/nifty5/__init__.py
+++ b/nifty5/__init__.py
@@ -69,7 +69,7 @@ from .minimization.scipy_minimizer import L_BFGS_B
 from .minimization.energy import Energy
 from .minimization.quadratic_energy import QuadraticEnergy
 from .minimization.energy_adapter import EnergyAdapter
-from .minimization.metric_gaussian_kl import MetricGaussianKL
+from .minimization.metric_gaussian_kl import MetricGaussianKL, MetricGaussianKL_MPI
 
 from .sugar import *
 from .plot import Plot
diff --git a/nifty5/dobj.py b/nifty5/dobj.py
index 47dc22979..9c56da4bc 100644
--- a/nifty5/dobj.py
+++ b/nifty5/dobj.py
@@ -15,11 +15,16 @@
 #
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
 
-try:
-    from mpi4py import MPI
-    if MPI.COMM_WORLD.Get_size() == 1:
-        from .data_objects.numpy_do import *
-    else:
-        from .data_objects.distributed_do import *
-except ImportError:
+_always_use_numpy_do = True
+
+if _always_use_numpy_do:
     from .data_objects.numpy_do import *
+else:
+    try:
+        from mpi4py import MPI
+        if MPI.COMM_WORLD.Get_size() == 1:
+            from .data_objects.numpy_do import *
+        else:
+            from .data_objects.distributed_do import *
+    except ImportError:
+        from .data_objects.numpy_do import *
diff --git a/nifty5/minimization/metric_gaussian_kl.py b/nifty5/minimization/metric_gaussian_kl.py
index c42844a36..277aec00a 100644
--- a/nifty5/minimization/metric_gaussian_kl.py
+++ b/nifty5/minimization/metric_gaussian_kl.py
@@ -148,3 +148,130 @@ class MetricGaussianKL(Energy):
     def __repr__(self):
         return 'KL ({} samples):\n'.format(len(
             self._samples)) + utilities.indent(self._hamiltonian.__repr__())
+
+
+from mpi4py import MPI
+import numpy as np
+from ..field import Field
+from ..multi_field import MultiField
+_comm = MPI.COMM_WORLD
+ntask = _comm.Get_size()
+rank = _comm.Get_rank()
+master = (rank == 0)
+
+
+def _shareRange(nwork, nshares, myshare):
+    nbase = nwork//nshares
+    additional = nwork % nshares
+    lo = myshare*nbase + min(myshare, additional)
+    hi = lo + nbase + int(myshare < additional)
+    return lo, hi
+
+
+def np_allreduce_sum(arr):
+    res = np.empty_like(arr)
+    _comm.Allreduce(arr, res, MPI.SUM)
+    return res
+
+
+def allreduce_sum_field(fld):
+    if isinstance(fld, Field):
+        return Field.from_local_data(fld.domain,
+                                     np_allreduce_sum(fld.local_data))
+    res = tuple(
+        Field.from_local_data(f.domain, np_allreduce_sum(f.local_data))
+        for f in fld.values())
+    return MultiField(fld.domain, res)
+
+
+class MetricGaussianKL_MPI(Energy):
+    def __init__(self, mean, hamiltonian, n_samples, constants=[],
+                 point_estimates=[], mirror_samples=False,
+                 _samples=None):
+        super(MetricGaussianKL_MPI, self).__init__(mean)
+
+        if not isinstance(hamiltonian, StandardHamiltonian):
+            raise TypeError
+        if hamiltonian.domain is not mean.domain:
+            raise ValueError
+        if not isinstance(n_samples, int):
+            raise TypeError
+        self._constants = list(constants)
+        self._point_estimates = list(point_estimates)
+        if not isinstance(mirror_samples, bool):
+            raise TypeError
+
+        self._hamiltonian = hamiltonian
+
+        if _samples is None:
+            lo, hi = _shareRange(n_samples, ntask, rank)
+            met = hamiltonian(Linearization.make_partial_var(
+                mean, point_estimates, True)).metric
+            _samples = []
+            for i in range(lo, hi):
+                np.random.seed(i)
+                _samples.append(met.draw_sample(from_inverse=True))
+            if mirror_samples:
+                _samples += [-s for s in _samples]
+                n_samples *= 2
+        _samples = tuple(_samples)
+        self._samples = _samples
+        self._n_samples = n_samples
+        self._lin = Linearization.make_partial_var(mean, constants)
+        v, g = None, None
+        if len(self._samples) == 0:  # hack if there are too many MPI tasks
+            tmp = self._hamiltonian(self._lin)
+            v = 0. * tmp.val.local_data[()]
+            g = 0. * tmp.gradient
+        else:
+            for s in self._samples:
+                tmp = self._hamiltonian(self._lin+s)
+                if v is None:
+                    v = tmp.val.local_data[()]
+                    g = tmp.gradient
+                else:
+                    v += tmp.val.local_data[()]
+                    g = g + tmp.gradient
+        self._val = np_allreduce_sum(v) / self._n_samples
+        self._grad = allreduce_sum_field(g) / self._n_samples
+        self._metric = None
+
+    def at(self, position):
+        return MetricGaussianKL_MPI(
+            position, self._hamiltonian, self._n_samples, self._constants,
+            self._point_estimates, _samples=self._samples)
+
+    @property
+    def value(self):
+        return self._val
+
+    @property
+    def gradient(self):
+        return self._grad
+
+    def _get_metric(self):
+        lin = self._lin.with_want_metric()
+        if self._metric is None:
+            if len(self._samples) == 0:  # hack if there are too many MPI tasks
+                self._metric = self._hamiltonian(lin).metric.scale(0.)
+            else:
+                mymap = map(lambda v: self._hamiltonian(lin+v).metric,
+                            self._samples)
+                self._metric = utilities.my_sum(mymap)
+                self._metric = self._metric.scale(1./self._n_samples)
+
+    def apply_metric(self, x):
+        self._get_metric()
+        return allreduce_sum_field(self._metric(x))
+
+    @property
+    def metric(self):
+        if ntask > 1:
+            raise ValueError("not supported when MPI is active")
+        return self._metric
+
+    @property
+    def samples(self):
+        res = _comm.allgather(self._samples)
+        res = [item for sublist in res for item in sublist]
+        return res
-- 
GitLab