From 6f19b45d92bcdbec2856b60587b61a5c16e2cf13 Mon Sep 17 00:00:00 2001
From: Theo Steininger <theo.steininger@ultimanet.de>
Date: Sat, 28 Jan 2017 02:40:25 +0100
Subject: [PATCH] bincount in _slicing_distributor now tries to preserve a
 global_type distribution strategy if possible.

---
 d2o/distributor_factory.py | 28 +++++++++++++++++++++++++---
 d2o/version.py             |  2 +-
 2 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/d2o/distributor_factory.py b/d2o/distributor_factory.py
index 1105925..1c0bb9b 100644
--- a/d2o/distributor_factory.py
+++ b/d2o/distributor_factory.py
@@ -1656,13 +1656,35 @@ class _slicing_distributor(distributor):
             result_object = obj.copy_empty(global_shape=global_counts.shape,
                                            dtype=global_counts.dtype,
                                            distribution_strategy='not')
+            result_object.set_local_data(global_counts, copy=False)
+
         else:
             global_counts = local_counts
-            result_object = obj.copy_empty(local_shape=global_counts.shape,
+            global_shape = list(global_counts.shape)
+            global_shape[0] = self.global_shape[0]
+
+            # Try to preserve the distribution_strategy if possible.
+            # Choose the local distribution_strategy if...
+            # -> if the distribution_strategy was local in the beginning
+            # -> if the slicing of the first axis changes even though
+            #    the bincounting wasn't performed on it. This can be the case
+            #    if axis=0 is extremely small and collapsing axis=1,2,3...
+            #    forces fftw to use a different slicing.
+            # Both cases are triggered by an ValueError exception
+
+            try:
+                if self.distribution_strategy not in STRATEGIES['global']:
+                    raise ValueError
+
+                result_object = obj.copy_empty(global_shape=global_shape,
+                                               dtype=global_counts.dtype)
+                result_object.set_local_data(global_counts, copy=False)
+
+            except ValueError:
+                result_object = obj.copy_empty(
+                                           local_shape=global_counts.shape,
                                            dtype=global_counts.dtype,
                                            distribution_strategy='freeform')
-
-        result_object.set_local_data(global_counts, copy=False)
         return result_object
 
 
diff --git a/d2o/version.py b/d2o/version.py
index a0bd678..16c79f9 100644
--- a/d2o/version.py
+++ b/d2o/version.py
@@ -20,4 +20,4 @@
 # 1) we don't load dependencies by storing it in __init__.py
 # 2) we can import it in setup.py for the same reason
 # 3) we can import it into your module module
-__version__ = '1.0.6'
+__version__ = '1.0.7'
-- 
GitLab