Commit 57511c77 authored by Martin Reinecke's avatar Martin Reinecke

tentative fix

parent 05784024
Pipeline #44090 passed with stages
in 8 minutes and 9 seconds
......@@ -92,18 +92,18 @@ class MetricGaussianKL_MPI(Energy):
v, g = None, None
if len(self._samples) == 0: # hack if there are too many MPI tasks
tmp = self._hamiltonian(self._lin)
v = 0. * tmp.val.local_data[()]
v = 0. * tmp.val.local_data
g = 0. * tmp.gradient
else:
for s in self._samples:
tmp = self._hamiltonian(self._lin+s)
if v is None:
v = tmp.val.local_data[()]
v = tmp.val.local_data.copy()
g = tmp.gradient
else:
v += tmp.val.local_data[()]
v += tmp.val.local_data
g = g + tmp.gradient
self._val = np_allreduce_sum(v) / self._n_samples
self._val = np_allreduce_sum(v)[()] / self._n_samples
self._grad = allreduce_sum_field(g) / self._n_samples
self._metric = None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment