Skip to content
Snippets Groups Projects
Commit cbc3b165 authored by Martin Reinecke
Browse files

re-introduce mpi_experiments

parent 6fcd6fb6
No related branches found
No related tags found
1 merge request: !298 "re-introduce mpi_experiments"
Pipeline #44063 passed
......@@ -69,7 +69,7 @@ from .minimization.scipy_minimizer import L_BFGS_B
from .minimization.energy import Energy
from .minimization.quadratic_energy import QuadraticEnergy
from .minimization.energy_adapter import EnergyAdapter
from .minimization.metric_gaussian_kl import MetricGaussianKL
from .minimization.metric_gaussian_kl import MetricGaussianKL, MetricGaussianKL_MPI
from .sugar import *
from .plot import Plot
......
......@@ -15,11 +15,16 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
try:
from mpi4py import MPI
if MPI.COMM_WORLD.Get_size() == 1:
from .data_objects.numpy_do import *
else:
from .data_objects.distributed_do import *
except ImportError:
_always_use_numpy_do = True
if _always_use_numpy_do:
from .data_objects.numpy_do import *
else:
try:
from mpi4py import MPI
if MPI.COMM_WORLD.Get_size() == 1:
from .data_objects.numpy_do import *
else:
from .data_objects.distributed_do import *
except ImportError:
from .data_objects.numpy_do import *
......@@ -148,3 +148,130 @@ class MetricGaussianKL(Energy):
def __repr__(self):
    """Return a short textual description of this KL energy.

    The first line states how many samples are stored; it is followed by
    the repr of the Hamiltonian after it has been passed through
    ``utilities.indent``.
    """
    header = 'KL ({} samples):\n'.format(len(self._samples))
    details = utilities.indent(self._hamiltonian.__repr__())
    return header + details
from mpi4py import MPI
import numpy as np
from ..field import Field
from ..multi_field import MultiField
# Module-wide MPI state, evaluated once when this module is imported.
# NOTE(review): the `from mpi4py import MPI` above is unconditional, so
# importing this module requires mpi4py to be installed -- confirm this is
# intended, given the ImportError fallback used for the data-object import.
# World communicator used by every collective operation in this module.
_comm = MPI.COMM_WORLD
# Number of MPI tasks in the world communicator.
ntask = _comm.Get_size()
# Index of the current task within the world communicator.
rank = _comm.Get_rank()
# True only on the task with rank 0.
master = (rank == 0)
def _shareRange(nwork, nshares, myshare):
    """Return the half-open index range ``(lo, hi)`` owned by one share.

    ``nwork`` items are divided into ``nshares`` contiguous chunks. Every
    chunk holds ``nwork // nshares`` items, and the first
    ``nwork % nshares`` chunks hold one more.

    Parameters
    ----------
    nwork : int
        Total number of work items.
    nshares : int
        Number of chunks the work is split into.
    myshare : int
        Index of the chunk whose range is requested.

    Returns
    -------
    tuple of int
        ``(lo, hi)`` such that the chunk covers ``range(lo, hi)``.
    """
    base, extra = divmod(nwork, nshares)
    # Each chunk before this one contributes `base` items, and the first
    # `extra` of them contribute one additional item.
    start = myshare * base + min(myshare, extra)
    stop = start + base + int(myshare < extra)
    return start, stop
def np_allreduce_sum(arr):
    """Sum ``arr`` element-wise over all tasks of the module communicator.

    A fresh output buffer is allocated with ``np.empty_like(arr)`` and
    filled by ``Allreduce`` with the ``MPI.SUM`` operation; ``arr`` itself
    is only used as the send buffer.

    Returns
    -------
    The newly allocated buffer holding the reduced result.
    """
    total = np.empty_like(arr)
    _comm.Allreduce(arr, total, op=MPI.SUM)
    return total
def allreduce_sum_field(fld):
    """Sum a field over all MPI tasks and return the result as a new field.

    If ``fld`` is a ``Field``, its local data is reduced with
    ``np_allreduce_sum`` and wrapped in a new ``Field`` on the same domain.
    Any other input is treated as a multi-field: each entry returned by
    ``fld.values()`` is reduced the same way and the results are collected
    into a ``MultiField`` on ``fld.domain``.
    """
    def _reduced(single):
        # Reduce the local data of one Field and rebuild it on its domain.
        return Field.from_local_data(single.domain,
                                     np_allreduce_sum(single.local_data))

    if isinstance(fld, Field):
        return _reduced(fld)

    components = tuple(_reduced(entry) for entry in fld.values())
    return MultiField(fld.domain, components)
class MetricGaussianKL_MPI(Energy):
    """Sample-averaged Hamiltonian energy with samples spread over MPI tasks.

    Each task owns the samples whose indices fall into the range returned
    by ``_shareRange(n_samples, ntask, rank)``. Value and gradient are
    summed over the local samples, combined across tasks with an
    all-reduce, and divided by the total number of samples.

    Parameters
    ----------
    mean : Field or MultiField
        Position at which the energy is evaluated.
    hamiltonian : StandardHamiltonian
        Hamiltonian to be averaged; its domain must be the same object as
        ``mean.domain``.
    n_samples : int
        Total number of samples over all tasks (before mirroring).
    constants : iterable, optional
        Passed to ``Linearization.make_partial_var`` when building the
        linearization used for value, gradient and metric.
    point_estimates : iterable, optional
        Passed to ``Linearization.make_partial_var`` when building the
        metric from which samples are drawn.
    mirror_samples : bool, optional
        If True, the negative of every locally drawn sample is appended
        and the total sample count is doubled.
    _samples : tuple, optional
        Pre-drawn local samples. When given, no samples are drawn and
        ``n_samples`` is used as the total count unchanged. Used by `at`.

    Raises
    ------
    TypeError
        If ``hamiltonian`` is not a ``StandardHamiltonian``, ``n_samples``
        is not an ``int`` or ``mirror_samples`` is not a ``bool``.
    ValueError
        If ``hamiltonian.domain`` is not ``mean.domain``.
    """

    def __init__(self, mean, hamiltonian, n_samples, constants=(),
                 point_estimates=(), mirror_samples=False,
                 _samples=None):
        super(MetricGaussianKL_MPI, self).__init__(mean)

        if not isinstance(hamiltonian, StandardHamiltonian):
            raise TypeError("hamiltonian must be a StandardHamiltonian")
        if hamiltonian.domain is not mean.domain:
            raise ValueError("hamiltonian.domain must be mean.domain")
        if not isinstance(n_samples, int):
            raise TypeError("n_samples must be an int")
        if not isinstance(mirror_samples, bool):
            raise TypeError("mirror_samples must be a bool")

        # Private copies, so later changes by the caller cannot leak in.
        self._constants = list(constants)
        self._point_estimates = list(point_estimates)
        self._hamiltonian = hamiltonian

        if _samples is None:
            # Draw only the samples assigned to this task.
            lo, hi = _shareRange(n_samples, ntask, rank)
            met = hamiltonian(Linearization.make_partial_var(
                mean, self._point_estimates, True)).metric
            _samples = []
            for i in range(lo, hi):
                # Seed the global NumPy RNG with the global sample index
                # (presumably so that sample i does not depend on the
                # number of tasks -- confirm).
                np.random.seed(i)
                _samples.append(met.draw_sample(from_inverse=True))
            if mirror_samples:
                _samples += [-s for s in _samples]
                n_samples *= 2
            _samples = tuple(_samples)
        self._samples = _samples
        self._n_samples = n_samples
        self._lin = Linearization.make_partial_var(mean, self._constants)

        v, g = None, None
        if len(self._samples) == 0:  # hack if there are too many MPI tasks
            # This task holds no samples: evaluate once and scale by zero
            # to obtain zero contributions for the reductions below.
            tmp = self._hamiltonian(self._lin)
            v = 0. * tmp.val.local_data[()]
            g = 0. * tmp.gradient
        else:
            for s in self._samples:
                tmp = self._hamiltonian(self._lin+s)
                if v is None:
                    v = tmp.val.local_data[()]
                    g = tmp.gradient
                else:
                    v += tmp.val.local_data[()]
                    g = g + tmp.gradient

        # Combine the per-task sums and normalise by the global count.
        self._val = np_allreduce_sum(v) / self._n_samples
        self._grad = allreduce_sum_field(g) / self._n_samples
        self._metric = None

    def at(self, position):
        """Return a new energy at ``position`` reusing the local samples."""
        return MetricGaussianKL_MPI(
            position, self._hamiltonian, self._n_samples, self._constants,
            self._point_estimates, _samples=self._samples)

    @property
    def value(self):
        """Sample-averaged value, already reduced over all tasks."""
        return self._val

    @property
    def gradient(self):
        """Sample-averaged gradient, already reduced over all tasks."""
        return self._grad

    def _get_metric(self):
        """Build and cache this task's share of the averaged metric."""
        if self._metric is not None:
            return
        lin = self._lin.with_want_metric()
        if len(self._samples) == 0:  # hack if there are too many MPI tasks
            self._metric = self._hamiltonian(lin).metric.scale(0.)
        else:
            mymap = map(lambda v: self._hamiltonian(lin+v).metric,
                        self._samples)
            self._metric = utilities.my_sum(mymap)
            self._metric = self._metric.scale(1./self._n_samples)

    def apply_metric(self, x):
        """Apply the averaged metric to ``x``, summing over all tasks."""
        self._get_metric()
        return allreduce_sum_field(self._metric(x))

    @property
    def metric(self):
        """The averaged metric operator; only available with one MPI task.

        Raises
        ------
        ValueError
            If more than one MPI task is active, because the cached
            operator then covers only this task's samples.
        """
        if ntask > 1:
            raise ValueError("not supported when MPI is active")
        # Make sure the metric exists; it is built lazily and would
        # otherwise still be None if apply_metric was never called.
        self._get_metric()
        return self._metric

    @property
    def samples(self):
        """List of the samples of all tasks, gathered in rank order."""
        res = _comm.allgather(self._samples)
        res = [item for sublist in res for item in sublist]
        return res
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment