Commit 0aa58a2d authored by Martin Reinecke's avatar Martin Reinecke
Browse files

provide communicator instead of use_mpi flag

parent f629a706
......@@ -110,10 +110,9 @@ class MetricGaussianKL(Energy):
napprox : int
Number of samples for computing preconditioner for sampling. No
preconditioning is done by default.
use_mpi : bool
whether MPI should be used.
If MPI is enabled, samples will be distributed as evenly as possible
across MPI.COMM_WORLD. If `mirror_samples` is set, then a sample and
comm : MPI communicator or None
If not None, samples will be distributed as evenly as possible
across this communicator. If `mirror_samples` is set, then a sample and
its mirror image will always reside on the same task.
lh_sampling_dtype : type
Determines which dtype in data space shall be used for drawing samples
......@@ -139,7 +138,7 @@ class MetricGaussianKL(Energy):
def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False,
napprox=0, use_mpi=False, _samples=None,
napprox=0, comm=None, _samples=None,
super(MetricGaussianKL, self).__init__(mean)
......@@ -157,10 +156,8 @@ class MetricGaussianKL(Energy):
self._hamiltonian = hamiltonian
self._n_samples = int(n_samples)
self._use_mpi = bool(use_mpi)
if self._use_mpi:
from mpi4py import MPI
self._comm = MPI.COMM_WORLD
if comm is not None:
self._comm = comm
ntask = self._comm.Get_size()
rank = self._comm.Get_rank()
self._lo, self._hi = _shareRange(self._n_samples, ntask, rank)
......@@ -215,7 +212,7 @@ class MetricGaussianKL(Energy):
def at(self, position):
return MetricGaussianKL(
position, self._hamiltonian, self._n_samples, self._constants,
self._point_estimates, self._mirror_samples, use_mpi=self._use_mpi,
self._point_estimates, self._mirror_samples, comm=self._comm,
_samples=self._samples, lh_sampling_dtype=self._sampdt)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment