From 74fdeb6b1918eeb48370f25ee1f6229f8d6e47eb Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Wed, 29 Aug 2018 21:39:35 +0200 Subject: [PATCH] simplify --- demos/getting_started_3.py | 2 +- nifty5/minimization/kl_energy.py | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/demos/getting_started_3.py b/demos/getting_started_3.py index cd2f54c7d..390b791e0 100644 --- a/demos/getting_started_3.py +++ b/demos/getting_started_3.py @@ -89,7 +89,7 @@ if __name__ == '__main__': plot.output(ny=1, nx=3, xsize=24, ysize=6, name="setup.png") # number of samples used to estimate the KL - N_samples = 20 + N_samples = 1 for i in range(2): # KL = ift.KL_Energy(position, H, N_samples) KL = ift.KL_Energy_MPI(position, H, N_samples, want_metric=True) diff --git a/nifty5/minimization/kl_energy.py b/nifty5/minimization/kl_energy.py index b23aa2b83..71d0ee690 100644 --- a/nifty5/minimization/kl_energy.py +++ b/nifty5/minimization/kl_energy.py @@ -54,12 +54,6 @@ class KL_Energy_MPI(Energy): self._nsamp = nsamp self._constants = constants self._want_metric = want_metric - if nsamp < ntask: - # FIXME We need a better solution here. It is probably not good if - # the script just dies. Can we proceed anyways? - print('Number of samples: {}, number of MPI tasks: {}'.format( - nsamp, ntask)) - raise RuntimeError('Cannot use more tasks than samples.') if _samples is None: lo, hi = _shareRange(nsamp, ntask, rank) met = h(Linearization.make_var(position, True)).metric @@ -70,7 +64,10 @@ class KL_Energy_MPI(Energy): self._samples = tuple(_samples) self._lin = Linearization.make_partial_var(position, constants, want_metric) - mymap = map(lambda v: self._h(self._lin + v), self._samples) + if len(self._samples) == 0: # hack if there are too many MPI tasks + mymap = map(lambda v: 0*self._h(v), (self._lin,)) + else: + mymap = map(lambda v: self._h(self._lin + v), self._samples) tmp = utilities.my_sum(mymap)*(1./self._nsamp) self._val = np_allreduce_sum(tmp.val.local_data)[()] self._grad = allreduce_sum_field(tmp.gradient) -- GitLab