From 74fdeb6b1918eeb48370f25ee1f6229f8d6e47eb Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Wed, 29 Aug 2018 21:39:35 +0200
Subject: [PATCH] simplify

---
 demos/getting_started_3.py       |  2 +-
 nifty5/minimization/kl_energy.py | 11 ++++-------
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/demos/getting_started_3.py b/demos/getting_started_3.py
index cd2f54c7d..390b791e0 100644
--- a/demos/getting_started_3.py
+++ b/demos/getting_started_3.py
@@ -89,7 +89,7 @@ if __name__ == '__main__':
     plot.output(ny=1, nx=3, xsize=24, ysize=6, name="setup.png")
 
     # number of samples used to estimate the KL
-    N_samples = 20
+    N_samples = 1
     for i in range(2):
 #        KL = ift.KL_Energy(position, H, N_samples)
         KL = ift.KL_Energy_MPI(position, H, N_samples, want_metric=True)
diff --git a/nifty5/minimization/kl_energy.py b/nifty5/minimization/kl_energy.py
index b23aa2b83..71d0ee690 100644
--- a/nifty5/minimization/kl_energy.py
+++ b/nifty5/minimization/kl_energy.py
@@ -54,12 +54,6 @@ class KL_Energy_MPI(Energy):
         self._nsamp = nsamp
         self._constants = constants
         self._want_metric = want_metric
-        if nsamp < ntask:
-            # FIXME We need a better solution here. It is probably not good if
-            # the script just dies. Can we proceed anyways?
-            print('Number of samples: {}, number of MPI tasks: {}'.format(
-                nsamp, ntask))
-            raise RuntimeError('Cannot use more tasks than samples.')
         if _samples is None:
             lo, hi = _shareRange(nsamp, ntask, rank)
             met = h(Linearization.make_var(position, True)).metric
@@ -70,7 +64,10 @@ class KL_Energy_MPI(Energy):
         self._samples = tuple(_samples)
         self._lin = Linearization.make_partial_var(position, constants,
                                                    want_metric)
-        mymap = map(lambda v: self._h(self._lin + v), self._samples)
+        if len(self._samples) == 0:  # hack if there are too many MPI tasks
+            mymap = map(lambda v: 0*self._h(v), (self._lin,))
+        else:
+            mymap = map(lambda v: self._h(self._lin + v), self._samples)
         tmp = utilities.my_sum(mymap)*(1./self._nsamp)
         self._val = np_allreduce_sum(tmp.val.local_data)[()]
         self._grad = allreduce_sum_field(tmp.gradient)
-- 
GitLab