diff --git a/demos/getting_started_3.py b/demos/getting_started_3.py index cd2f54c7da396fcf22ce9f723b8e21ce0a71dac3..390b791e0c390e45b6a926ac71fa375a6cc78e36 100644 --- a/demos/getting_started_3.py +++ b/demos/getting_started_3.py @@ -89,7 +89,7 @@ if __name__ == '__main__': plot.output(ny=1, nx=3, xsize=24, ysize=6, name="setup.png") # number of samples used to estimate the KL - N_samples = 20 + N_samples = 1 for i in range(2): # KL = ift.KL_Energy(position, H, N_samples) KL = ift.KL_Energy_MPI(position, H, N_samples, want_metric=True) diff --git a/nifty5/minimization/kl_energy.py b/nifty5/minimization/kl_energy.py index b23aa2b83273120eab47a82519dc0f2d37a0d4f0..71d0ee6903abe0e6e74005666309444dae51e5b2 100644 --- a/nifty5/minimization/kl_energy.py +++ b/nifty5/minimization/kl_energy.py @@ -54,12 +54,6 @@ class KL_Energy_MPI(Energy): self._nsamp = nsamp self._constants = constants self._want_metric = want_metric - if nsamp < ntask: - # FIXME We need a better solution here. It is probably not good if - # the script just dies. Can we proceed anyways? - print('Number of samples: {}, number of MPI tasks: {}'.format( - nsamp, ntask)) - raise RuntimeError('Cannot use more tasks than samples.') if _samples is None: lo, hi = _shareRange(nsamp, ntask, rank) met = h(Linearization.make_var(position, True)).metric @@ -70,7 +64,10 @@ class KL_Energy_MPI(Energy): self._samples = tuple(_samples) self._lin = Linearization.make_partial_var(position, constants, want_metric) - mymap = map(lambda v: self._h(self._lin + v), self._samples) + if len(self._samples) == 0: # hack if there are too many MPI tasks + mymap = map(lambda v: 0*self._h(v), (self._lin,)) + else: + mymap = map(lambda v: self._h(self._lin + v), self._samples) tmp = utilities.my_sum(mymap)*(1./self._nsamp) self._val = np_allreduce_sum(tmp.val.local_data)[()] self._grad = allreduce_sum_field(tmp.gradient)