diff --git a/demos/getting_started_3.py b/demos/getting_started_3.py
index cd2f54c7da396fcf22ce9f723b8e21ce0a71dac3..390b791e0c390e45b6a926ac71fa375a6cc78e36 100644
--- a/demos/getting_started_3.py
+++ b/demos/getting_started_3.py
@@ -89,7 +89,7 @@ if __name__ == '__main__':
     plot.output(ny=1, nx=3, xsize=24, ysize=6, name="setup.png")
 
     # number of samples used to estimate the KL
-    N_samples = 20
+    N_samples = 1
     for i in range(2):
 #        KL = ift.KL_Energy(position, H, N_samples)
         KL = ift.KL_Energy_MPI(position, H, N_samples, want_metric=True)
diff --git a/nifty5/minimization/kl_energy.py b/nifty5/minimization/kl_energy.py
index b23aa2b83273120eab47a82519dc0f2d37a0d4f0..71d0ee6903abe0e6e74005666309444dae51e5b2 100644
--- a/nifty5/minimization/kl_energy.py
+++ b/nifty5/minimization/kl_energy.py
@@ -54,12 +54,6 @@ class KL_Energy_MPI(Energy):
         self._nsamp = nsamp
         self._constants = constants
         self._want_metric = want_metric
-        if nsamp < ntask:
-            # FIXME We need a better solution here. It is probably not good if
-            # the script just dies. Can we proceed anyways?
-            print('Number of samples: {}, number of MPI tasks: {}'.format(
-                nsamp, ntask))
-            raise RuntimeError('Cannot use more tasks than samples.')
         if _samples is None:
             lo, hi = _shareRange(nsamp, ntask, rank)
             met = h(Linearization.make_var(position, True)).metric
@@ -70,7 +64,10 @@ class KL_Energy_MPI(Energy):
         self._samples = tuple(_samples)
         self._lin = Linearization.make_partial_var(position, constants,
                                                    want_metric)
-        mymap = map(lambda v: self._h(self._lin + v), self._samples)
+        if len(self._samples) == 0:  # hack if there are too many MPI tasks
+            mymap = map(lambda v: 0*self._h(v), (self._lin,))
+        else:
+            mymap = map(lambda v: self._h(self._lin + v), self._samples)
         tmp = utilities.my_sum(mymap)*(1./self._nsamp)
         self._val = np_allreduce_sum(tmp.val.local_data)[()]
         self._grad = allreduce_sum_field(tmp.gradient)