Commit 72ebb333 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

improve utilities

parent ca67aa1e
Pipeline #76320 passed with stages
in 23 minutes and 34 seconds
......@@ -126,14 +126,9 @@ class MetricGaussianKL(Energy):
self._hamiltonian = hamiltonian
self._n_samples = int(n_samples)
if comm is not None:
self._comm = comm
ntask = self._comm.Get_size()
rank = self._comm.Get_rank()
self._lo, self._hi = utilities.shareRange(self._n_samples, ntask, rank)
else:
self._comm = None
self._lo, self._hi = 0, self._n_samples
self._comm = comm
ntask, rank, _ = utilities.get_MPI_params_from_comm(self._comm)
self._lo, self._hi = utilities.shareRange(self._n_samples, ntask, rank)
self._mirror_samples = bool(mirror_samples)
self._n_eff_samples = self._n_samples
......@@ -202,14 +197,13 @@ class MetricGaussianKL(Energy):
@property
def samples(self):
if self._comm is None:
ntask, rank, _ = utilities.get_MPI_params_from_comm(self._comm)
if ntask == 1:
for s in self._local_samples:
yield s
if self._mirror_samples:
yield -s
else:
ntask = self._comm.Get_size()
rank = self._comm.Get_rank()
rank_lo_hi = [utilities.shareRange(self._n_samples, ntask, i) for i in range(ntask)]
for itask, (l, h) in enumerate(rank_lo_hi):
for i in range(l, h):
......
......@@ -277,6 +277,16 @@ def shareRange(nwork, nshares, myshare):
return lo, hi
def get_MPI_params_from_comm(comm):
if comm is None:
return 1, 0, True
size = comm.Get_size()
rank = comm.Get_rank()
return size, rank, rank == 0
def get_MPI_params():
"""Returns basic information about the MPI setup of the running script.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment