Commit fd4352e0 authored by Martin Reinecke

working version, but probably still wasteful

parent cda0e764
@@ -36,27 +36,6 @@ def _shareRange(nwork, nshares, myshare):
     return lo, hi
 
-def _np_allreduce_sum(comm, arr):
-    if comm is None:
-        return arr
-    from mpi4py import MPI
-    arr = np.array(arr)
-    res = np.empty_like(arr)
-    comm.Allreduce(arr, res, MPI.SUM)
-    return res
-
-
-def _allreduce_sum_field(comm, fld):
-    if comm is None:
-        return fld
-    if isinstance(fld, Field):
-        return Field(fld.domain, _np_allreduce_sum(comm, fld.val))
-    res = tuple(
-        Field(f.domain, _np_allreduce_sum(comm, f.val))
-        for f in fld.values())
-    return MultiField(fld.domain, res)
-
-
 class _KLMetric(EndomorphicOperator):
     def __init__(self, KL):
         self._KL = KL
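For orientation, since the reworked reduction below leans on `_shareRange`: a minimal sketch of how such a helper typically partitions `nwork` items into contiguous per-task chunks. The body here is an assumption inferred from the `(lo, hi)` return above and the call `_shareRange(self._n_samples, ntask, i)` further down, not necessarily the repository's implementation.

```python
# Hedged sketch: contiguous, near-even split of nwork items over nshares
# MPI tasks; the first (nwork % nshares) tasks receive one extra item.
def _shareRange(nwork, nshares, myshare):
    nbase = nwork // nshares        # minimum chunk size per task
    extra = nwork % nshares         # tasks 0..extra-1 get one more item
    lo = myshare*nbase + min(myshare, extra)
    hi = lo + nbase + (1 if myshare < extra else 0)
    return lo, hi

# e.g. 10 samples over 4 tasks -> (0, 3), (3, 6), (6, 8), (8, 10)
assert [_shareRange(10, 4, i) for i in range(4)] == \
    [(0, 3), (3, 6), (6, 8), (8, 10)]
```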
@@ -185,26 +164,15 @@ class MetricGaussianKL(Energy):
             raise ValueError("# of samples mismatch")
         self._local_samples = _local_samples
         self._lin = Linearization.make_partial_var(mean, self._constants)
-        v, g = None, None
-        if len(self._local_samples) == 0:  # hack if there are too many MPI tasks
-            tmp = self._hamiltonian(self._lin)
-            v = 0. * tmp.val.val
-            g = 0. * tmp.gradient
-        for s in self._local_samples:
-            tmp = self._hamiltonian(self._lin+s)
-            if self._mirror_samples:
-                tmp = tmp + self._hamiltonian(self._lin-s)
-            if v is None:
-                v = tmp.val.val_rw()
-                g = tmp.gradient
-            else:
-                v += tmp.val.val
-                g = g + tmp.gradient
-        self._val = _np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples
+        v, g = [], []
+        for s in self._locsamp:
+            tmp = self._hamiltonian(self._lin+s)
+            v.append(tmp.val.val)
+            g.append(tmp.gradient)
+        self._val = self._sumup(v)[()]/self._n_eff_samples
         if np.isnan(self._val) and self._mitigate_nans:
             self._val = np.inf
-        self._grad = _allreduce_sum_field(self._comm, g) / self._n_eff_samples
+        self._grad = self._sumup(g)/self._n_eff_samples
         self._metric = None
 
     def at(self, position):
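The switch from `_np_allreduce_sum` to collecting per-sample terms and summing them in a fixed order via `_sumup` (defined further down) makes the result independent of how MPI groups the partial sums. A small self-contained demo of why grouping matters for floating-point sums; the values are illustrative only:

```python
import numpy as np

# Floating-point addition is not associative, so a tree-shaped MPI
# Allreduce can give results that depend on the number of tasks.
# Summing the same four numbers with two different groupings:
vals = np.array([1e16, 1.0, -1e16, 1.0])

left_to_right = ((vals[0] + vals[1]) + vals[2]) + vals[3]
pairwise = (vals[0] + vals[1]) + (vals[2] + vals[3])

print(left_to_right)  # 1.0: the first 1.0 is absorbed by 1e16
print(pairwise)       # 0.0: both 1.0 terms are absorbed
```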
@@ -237,8 +205,11 @@ class MetricGaussianKL(Energy):
         self._metric = unscaled_metric.scale(1./self._n_eff_samples)
 
     def apply_metric(self, x):
-        return _allreduce_sum_field(self._comm, self._metric(x))
+        lin = self._lin.with_want_metric()
+        res = []
+        for s in self._locsamp:
+            res.append(self._hamiltonian(lin+s).metric(x))
+        return self._sumup(res)/self._n_eff_samples
 
     @property
     def metric(self):
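`apply_metric` now applies each per-sample metric to `x` and sums the applications, rather than materializing one pre-summed operator. A toy sketch of that pattern with plain callables standing in for NIFTy operators (all names below are illustrative, not from the repository):

```python
# Apply a sum of operators by summing the individual applications:
# ((A1 + ... + An)(x)) / n_eff == (A1(x) + ... + An(x)) / n_eff
def apply_summed(operators, x, n_eff):
    res = None
    for op in operators:
        y = op(x)
        res = y if res is None else res + y
    return res / n_eff

# toy usage: two scaling "metrics" acting on a float
metrics = [lambda x: 2.0 * x, lambda x: 3.0 * x]
assert apply_summed(metrics, 1.0, len(metrics)) == 2.5
```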
@@ -263,15 +234,48 @@ class MetricGaussianKL(Energy):
             if self._mirror_samples:
                 yield -s
 
+    def _sumup(self, obj):
+        res = None
+        if self._comm is None:
+            for o in obj:
+                res = o if res is None else res + o
+        else:
+            ntask = self._comm.Get_size()
+            rank = self._comm.Get_rank()
+            rank_lo_hi = [_shareRange(self._n_samples, ntask, i)
+                          for i in range(ntask)]
+            for itask, (l, h) in enumerate(rank_lo_hi):
+                for i in range(l, h):
+                    iloc = i - self._lo
+                    if self._mirror_samples:
+                        o = obj[2*iloc] if rank == itask else None
+                        o = self._comm.bcast(o, root=itask)
+                        res = o if res is None else res + o
+                        o = obj[2*iloc+1] if rank == itask else None
+                        o = self._comm.bcast(o, root=itask)
+                        res = o if res is None else res + o
+                    else:
+                        o = obj[iloc] if rank == itask else None
+                        o = self._comm.bcast(o, root=itask)
+                        res = o if res is None else res + o
+        return res
+
+    @property
+    def _locsamp(self):
+        for s in self._local_samples:
+            yield s
+            if self._mirror_samples:
+                yield -s
+
     def _metric_sample(self, from_inverse=False):
         if from_inverse:
             raise NotImplementedError()
         lin = self._lin.with_want_metric()
-        samp = full(self._hamiltonian.domain, 0.)
+        samp = []
         sseq = random.spawn_sseq(self._n_samples)
         for i, v in enumerate(self._local_samples):
             with random.Context(sseq[self._lo+i]):
-                samp = samp + self._hamiltonian(lin+v).metric.draw_sample(from_inverse=False)
+                samp.append(self._hamiltonian(lin+v).metric.draw_sample(from_inverse=False))
                 if self._mirror_samples:
-                    samp = samp + self._hamiltonian(lin-v).metric.draw_sample(from_inverse=False)
-        return _allreduce_sum_field(self._comm, samp)/self._n_eff_samples
+                    samp.append(self._hamiltonian(lin-v).metric.draw_sample(from_inverse=False))
+        return self._sumup(samp)/self._n_eff_samples
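The `sseq = random.spawn_sseq(self._n_samples)` / `random.Context(sseq[self._lo+i])` pairing in `_metric_sample` keys each random stream to the *global* sample index, so a given metric sample is identical no matter which task draws it. Assuming NIFTy's helpers wrap numpy's `SeedSequence` machinery, the idea looks roughly like this (a sketch, not NIFTy's code):

```python
import numpy as np

def draw_local(master_seed, n_samples, lo, hi):
    # One child seed per *global* sample index; each task only
    # consumes the children in its own [lo, hi) range.
    children = np.random.SeedSequence(master_seed).spawn(n_samples)
    return [np.random.default_rng(children[i]).standard_normal()
            for i in range(lo, hi)]

# Two tasks owning disjoint ranges reproduce one serial run exactly.
serial = draw_local(0, 4, 0, 4)
assert serial[:2] == draw_local(0, 4, 0, 2)
assert serial[2:] == draw_local(0, 4, 2, 4)
```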