Skip to content
Snippets Groups Projects
Commit 0aa58a2d authored by Martin Reinecke's avatar Martin Reinecke
Browse files

provide communicator instead of use_mpi flag

parent f629a706
No related branches found
No related tags found
1 merge request!428Switch to new numpy random generators
...@@ -110,10 +110,9 @@ class MetricGaussianKL(Energy): ...@@ -110,10 +110,9 @@ class MetricGaussianKL(Energy):
napprox : int napprox : int
Number of samples for computing preconditioner for sampling. No Number of samples for computing preconditioner for sampling. No
preconditioning is done by default. preconditioning is done by default.
use_mpi : bool comm : MPI communicator or None
whether MPI should be used. If not None, samples will be distributed as evenly as possible
If MPI is enabled, samples will be distributed as evenly as possible across this communicator. If `mirror_samples` is set, then a sample and
across MPI.COMM_WORLD. If `mirror_samples` is set, then a sample and
its mirror image will always reside on the same task. its mirror image will always reside on the same task.
lh_sampling_dtype : type lh_sampling_dtype : type
Determines which dtype in data space shall be used for drawing samples Determines which dtype in data space shall be used for drawing samples
...@@ -139,7 +138,7 @@ class MetricGaussianKL(Energy): ...@@ -139,7 +138,7 @@ class MetricGaussianKL(Energy):
def __init__(self, mean, hamiltonian, n_samples, constants=[], def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False, point_estimates=[], mirror_samples=False,
napprox=0, use_mpi=False, _samples=None, napprox=0, comm=None, _samples=None,
lh_sampling_dtype=np.float64): lh_sampling_dtype=np.float64):
super(MetricGaussianKL, self).__init__(mean) super(MetricGaussianKL, self).__init__(mean)
...@@ -157,10 +156,8 @@ class MetricGaussianKL(Energy): ...@@ -157,10 +156,8 @@ class MetricGaussianKL(Energy):
self._hamiltonian = hamiltonian self._hamiltonian = hamiltonian
self._n_samples = int(n_samples) self._n_samples = int(n_samples)
self._use_mpi = bool(use_mpi) if comm is not None:
if self._use_mpi: self._comm = comm
from mpi4py import MPI
self._comm = MPI.COMM_WORLD
ntask = self._comm.Get_size() ntask = self._comm.Get_size()
rank = self._comm.Get_rank() rank = self._comm.Get_rank()
self._lo, self._hi = _shareRange(self._n_samples, ntask, rank) self._lo, self._hi = _shareRange(self._n_samples, ntask, rank)
...@@ -215,7 +212,7 @@ class MetricGaussianKL(Energy): ...@@ -215,7 +212,7 @@ class MetricGaussianKL(Energy):
def at(self, position): def at(self, position):
return MetricGaussianKL( return MetricGaussianKL(
position, self._hamiltonian, self._n_samples, self._constants, position, self._hamiltonian, self._n_samples, self._constants,
self._point_estimates, self._mirror_samples, use_mpi=self._use_mpi, self._point_estimates, self._mirror_samples, comm=self._comm,
_samples=self._samples, lh_sampling_dtype=self._sampdt) _samples=self._samples, lh_sampling_dtype=self._sampdt)
@property @property
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment