# -*- coding: utf-8 -*-

import numbers

import numpy as np

from nifty.keepers import about,\
                          global_configuration as gc,\
                          global_dependency_injector as gdi

from distributed_data_object import distributed_data_object

from d2o_iter import d2o_slicing_iter,\
                     d2o_not_iter
from d2o_librarian import d2o_librarian
from dtype_converter import dtype_converter

from cast_axis_to_tuple import cast_axis_to_tuple
from translate_to_mpi_operator import op_translate_dict

from strategies import STRATEGIES

MPI = gdi[gc['mpi_module']]
h5py = gdi.get('h5py')
pyfftw = gdi.get('pyfftw')


class _distributor_factory(object):

    def __init__(self):
        self.distributor_store = {}

    def parse_kwargs(self, distribution_strategy, comm,
                     global_data=None, global_shape=None,
                     local_data=None, local_shape=None,
                     alias=None, path=None,
                     dtype=None, skip_parsing=False, **kwargs):

        if skip_parsing:
            return_dict = {'comm': comm,
                           'dtype': dtype,
                           'name': distribution_strategy
                           }
            if distribution_strategy in STRATEGIES['global']:
                return_dict['global_shape'] = global_shape
            elif distribution_strategy in STRATEGIES['local']:
                return_dict['local_shape'] = local_shape
            return return_dict

        return_dict = {}

        expensive_checks = gc['d2o_init_checks']

        # Parse the MPI communicator
        if comm is None:
            raise ValueError(about._errors.cstring(
                "ERROR: The distributor needs MPI-communicator object comm!"))
        else:
            return_dict['comm'] = comm

        if expensive_checks:
            # Check that all nodes got the same distribution_strategy
            strat_list = comm.allgather(distribution_strategy)
            if not all(x == strat_list[0] for x in strat_list):
                raise ValueError(about._errors.cstring(
                    "ERROR: The distribution-strategy must be the same on " +
                    "all nodes!"))

        # Check for an hdf5 file and open it if given
        if 'h5py' in gdi and alias is not None:
            # set file path
            file_path = path if (path is not None) else alias
            # open hdf5 file
            if h5py.get_config().mpi and gc['mpi_module'] == 'MPI':
                f = h5py.File(file_path, 'r', driver='mpio', comm=comm)
            else:
                f = h5py.File(file_path, 'r')
            # open alias in file
            dset = f[alias]
        else:
            dset = None

        # Parse the datatype
        if distribution_strategy in ['not', 'equal', 'fftw'] and \
                (dset is not None):
            dtype = dset.dtype

        elif distribution_strategy in ['not', 'equal', 'fftw']:
            if dtype is None:
                if global_data is None:
                    dtype = np.dtype('float64')
                    about.infos.cprint('INFO: dtype was set to default.')
                else:
                    try:
                        dtype = global_data.dtype
                    except(AttributeError):
                        dtype = np.array(global_data).dtype
            else:
                dtype = np.dtype(dtype)

        elif distribution_strategy in STRATEGIES['local']:
            if dtype is None:
                if isinstance(global_data, distributed_data_object):
                    dtype = global_data.dtype
                elif local_data is not None:
                    try:
                        dtype = local_data.dtype
                    except(AttributeError):
                        dtype = np.array(local_data).dtype
                else:
                    dtype = np.dtype('float64')
                    about.infos.cprint('INFO: dtype was set to default.')

            else:
                dtype = np.dtype(dtype)
        if expensive_checks:
            dtype_list = comm.allgather(dtype)
            if not all(x == dtype_list[0] for x in dtype_list):
                raise ValueError(about._errors.cstring(
                    "ERROR: The given dtype must be the same on all nodes!"))
        return_dict['dtype'] = dtype

        # Parse the shape
        # Case 1: global-type slicer
        if distribution_strategy in STRATEGIES['global']:
            if dset is not None:
                global_shape = dset.shape
            elif global_data is not None and not np.isscalar(global_data):
                global_shape = global_data.shape
            elif global_shape is not None:
                global_shape = tuple(global_shape)
            else:
                raise ValueError(about._errors.cstring(
                    "ERROR: Neither non-0-dimensional global_data nor " +
                    "global_shape nor hdf5 file supplied!"))
            if global_shape == ():
                raise ValueError(about._errors.cstring(
                    "ERROR: global_shape == () is not a valid shape!"))

            if expensive_checks:
                global_shape_list = comm.allgather(global_shape)
                if not all(x == global_shape_list[0]
                           for x in global_shape_list):
                    raise ValueError(about._errors.cstring(
                        "ERROR: The global_shape must be the same on all " +
                        "nodes!"))
            return_dict['global_shape'] = global_shape

        # Case 2: local-type slicer
        elif distribution_strategy in ['freeform']:
            if isinstance(global_data, distributed_data_object):
                local_shape = global_data.local_shape
            elif local_data is not None and not np.isscalar(local_data):
                local_shape = local_data.shape
            elif local_shape is not None:
                local_shape = tuple(local_shape)
            else:
                raise ValueError(about._errors.cstring(
                    "ERROR: Neither non-0-dimensional local_data nor " +
                    "local_shape nor global d2o supplied!"))
            if local_shape == ():
                raise ValueError(about._errors.cstring(
                    "ERROR: local_shape == () is not a valid shape!"))

            if expensive_checks:
                local_shape_list = comm.allgather(local_shape[1:])
                cleared_set = set(local_shape_list)
                cleared_set.discard(())
                if len(cleared_set) > 1:
                    raise ValueError(about._errors.cstring(
                        "ERROR: All but the first entry of local_shape " +
                        "must be the same on all nodes!"))
            return_dict['local_shape'] = local_shape

        # Add the name of the distributor if needed
        if distribution_strategy in ['equal', 'fftw', 'freeform']:
            return_dict['name'] = distribution_strategy

        # close the file-handle
        if dset is not None:
            f.close()

        return return_dict
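    # Illustrative sketch (comment only, not executed): for an 'equal'
    # strategy with global_shape=(8, 8) and no global_data or dtype supplied,
    # one would expect parse_kwargs to return roughly
    #   {'comm': comm, 'dtype': np.dtype('float64'),
    #    'name': 'equal', 'global_shape': (8, 8)}
    # The exact content depends on the supplied arguments and on which
    # strategies are configured as 'global' or 'local'.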

    def hash_arguments(self, distribution_strategy, **kwargs):
        kwargs = kwargs.copy()

        comm = kwargs['comm']
        kwargs['comm'] = id(comm)

        if 'global_shape' in kwargs:
            kwargs['global_shape'] = kwargs['global_shape']
        if 'local_shape' in kwargs:
            local_shape = kwargs['local_shape']
            local_shape_list = comm.allgather(local_shape)
            kwargs['local_shape'] = tuple(local_shape_list)

        kwargs['dtype'] = self.dictionize_np(kwargs['dtype'])
        kwargs['distribution_strategy'] = distribution_strategy

        return frozenset(kwargs.items())
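    # Descriptive note: the communicator is represented by id(comm) and
    # 'local_shape' is replaced by the allgathered tuple of all ranks' local
    # shapes, so the returned frozenset is hashable and can serve as the key
    # into distributor_store in get_distributor below.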

    def dictionize_np(self, x):
        dic = x.type.__dict__.items()
        if x is np.float:
            dic[24] = 0
            dic[29] = 0
            dic[37] = 0
        return frozenset(dic)
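    # Descriptive note: dictionize_np turns a numpy scalar-type object into a
    # hashable frozenset so that the dtype can take part in the cache key
    # built by hash_arguments; for np.float three entries of the items list
    # are overwritten first (presumably to mask values that vary between
    # runs).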

    def get_distributor(self, distribution_strategy, comm, **kwargs):
        # check if the distribution strategy is known
        if distribution_strategy not in STRATEGIES['all']:
            raise ValueError(about._errors.cstring(
                "ERROR: Unknown distribution strategy supplied."))

        # parse the kwargs
        parsed_kwargs = self.parse_kwargs(
            distribution_strategy=distribution_strategy,
            comm=comm,
            **kwargs)

        hashed_kwargs = self.hash_arguments(distribution_strategy,
                                            **parsed_kwargs)
        # check if the distributor has already been produced in the past
        if hashed_kwargs in self.distributor_store:
            return self.distributor_store[hashed_kwargs]
        else:
            # produce a new distributor
            if distribution_strategy == 'not':
                produced_distributor = _not_distributor(**parsed_kwargs)

            elif distribution_strategy == 'equal':
                produced_distributor = _slicing_distributor(
                    slicer=_equal_slicer,
                    **parsed_kwargs)

            elif distribution_strategy == 'fftw':
                produced_distributor = _slicing_distributor(
                    slicer=_fftw_slicer,
                    **parsed_kwargs)
            elif distribution_strategy == 'freeform':
                produced_distributor = _slicing_distributor(
                    slicer=_freeform_slicer,
                    **parsed_kwargs)

            self.distributor_store[hashed_kwargs] = produced_distributor
            return self.distributor_store[hashed_kwargs]


distributor_factory = _distributor_factory()
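# Illustrative usage sketch (comment only, not executed on import; assumes an
# mpi4py setup and that the 'equal' strategy is configured):
#
#   from mpi4py import MPI
#   dist = distributor_factory.get_distributor(
#       distribution_strategy='equal',
#       comm=MPI.COMM_WORLD,
#       global_shape=(8, 8),
#       dtype=np.float64)
#   local_chunk = dist.distribute_data(data=np.arange(64).reshape(8, 8))
#
# Identical argument sets hash to the same key, so repeated calls return the
# cached distributor instead of constructing a new one.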


def _infer_key_type(key):
    if key is None:
        return (None, None)
    found_boolean = False
    # Check which case we got:
    if isinstance(key, tuple) or isinstance(key, slice) or np.isscalar(key):
        # Check if there is something different in the array than
        # scalars and slices
        if isinstance(key, slice) or np.isscalar(key):
            key = [key]

        scalarQ = np.array(map(np.isscalar, key))
        sliceQ = np.array(map(lambda z: isinstance(z, slice), key))
        if np.all(scalarQ + sliceQ):
            found = 'slicetuple'
        else:
            found = 'indexinglist'
    elif isinstance(key, np.ndarray):
        found = 'ndarray'
        found_boolean = (key.dtype.type == np.bool_)
    elif isinstance(key, distributed_data_object):
        found = 'd2o'
        found_boolean = (key.dtype == np.bool_)
    elif isinstance(key, list):
        found = 'indexinglist'
    else:
        raise ValueError(about._errors.cstring("ERROR: Unknown keytype!"))
    return (found, found_boolean)
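# Illustrative classification (comment-only sketch of the helper above):
#   _infer_key_type((slice(0, 2), 1))        -> ('slicetuple', False)
#   _infer_key_type(np.array([True, False])) -> ('ndarray', True)
#   _infer_key_type([np.array([0, 2])])      -> ('indexinglist', False)
#   _infer_key_type(None)                    -> (None, None)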


class distributor(object):
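    # Descriptive note: this base class collects the key-handling logic shared
    # by the concrete distributors below. With local_keys=True every MPI
    # process may pass a different key; the keys (and, if necessary, the
    # data_update objects) are gathered across the communicator and dispersed
    # one rank after another.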

    def disperse_data(self, data, to_key, data_update, from_key=None,
                      local_keys=False, copy=True, **kwargs):
        # Check which keys we got:
        (to_found, to_found_boolean) = _infer_key_type(to_key)
        (from_found, from_found_boolean) = _infer_key_type(from_key)

        comm = self.comm
        if local_keys is False:
            return self._disperse_data_primitive(
                                         data=data,
                                         to_key=to_key,
                                         data_update=data_update,
                                         from_key=from_key,
                                         copy=copy,
                                         to_found=to_found,
                                         to_found_boolean=to_found_boolean,
                                         from_found=from_found,
                                         from_found_boolean=from_found_boolean,
                                         **kwargs)

        else:
            # assert that all to_keys are of the same type
            to_found_list = comm.allgather(to_found)
            assert(all(x == to_found_list[0] for x in to_found_list))
            to_found_boolean_list = comm.allgather(to_found_boolean)
            assert(all(x == to_found_boolean_list[0] for x in
                       to_found_boolean_list))
            from_found_list = comm.allgather(from_found)
            assert(all(x == from_found_list[0] for x in from_found_list))
            from_found_boolean_list = comm.allgather(from_found_boolean)
            assert(all(x == from_found_boolean_list[0] for
                       x in from_found_boolean_list))

            # gather the local to_keys into a global to_key_list
            # Case 1: the to_keys are not distributed_data_objects
            # -> allgather does the job
            if to_found != 'd2o':
                to_key_list = comm.allgather(to_key)
            # Case 2: if the to_keys are distributed_data_objects, gather
            # the index of the array and build the to_key_list with help
            # from the librarian
            else:
                to_index_list = comm.allgather(to_key.index)
                to_key_list = map(lambda z: d2o_librarian[z], to_index_list)

            # gather the local from_keys. It is the same procedure as above
            if from_found != 'd2o':
                from_key_list = comm.allgather(from_key)
            else:
                from_index_list = comm.allgather(from_key.index)
                from_key_list = map(lambda z: d2o_librarian[z],
                                    from_index_list)

            local_data_update_is_scalar = np.isscalar(data_update)
            local_scalar_list = comm.allgather(local_data_update_is_scalar)
            for i in xrange(len(to_key_list)):
                if np.all(np.array(local_scalar_list) == True):
                    scalar_list = comm.allgather(data_update)
                    temp_data_update = scalar_list[i]
                elif isinstance(data_update, distributed_data_object):
                    data_update_index_list = comm.allgather(data_update.index)
                    data_update_list = map(lambda z: d2o_librarian[z],
                                           data_update_index_list)
                    temp_data_update = data_update_list[i]
                else:
                    # build a temporary freeform d2o which only contains data
                    # from node i
                    if comm.rank == i:
                        temp_shape = np.shape(data_update)
                        try:
                            temp_dtype = np.dtype(data_update.dtype)
                        except(TypeError):
                            temp_dtype = np.array(data_update).dtype
                    else:
                        temp_shape = None
                        temp_dtype = None
                    temp_shape = comm.bcast(temp_shape, root=i)
                    temp_dtype = comm.bcast(temp_dtype, root=i)

                    if comm.rank != i:
                        temp_shape = list(temp_shape)
                        temp_shape[0] = 0
                        temp_shape = tuple(temp_shape)
                        temp_data = np.empty(temp_shape, dtype=temp_dtype)
                    else:
                        temp_data = data_update
                    temp_data_update = distributed_data_object(
                                        local_data=temp_data,
                                        distribution_strategy='freeform',
                                        copy=False,
                                        comm=self.comm)
                # disperse the data one after another
                self._disperse_data_primitive(
                                      data=data,
                                      to_key=to_key_list[i],
                                      data_update=temp_data_update,
                                      from_key=from_key_list[i],
                                      copy=copy,
                                      to_found=to_found,
                                      to_found_boolean=to_found_boolean,
                                      from_found=from_found,
                                      from_found_boolean=from_found_boolean,
                                      **kwargs)


class _slicing_distributor(distributor):
    def __init__(self, slicer, name, dtype, comm, **remaining_parsed_kwargs):

        self.comm = comm
        self.distribution_strategy = name
        self.dtype = np.dtype(dtype)

        self._my_dtype_converter = dtype_converter

        if not self._my_dtype_converter.known_np_Q(self.dtype):
            raise TypeError(about._errors.cstring(
                "ERROR: The datatype " + str(self.dtype.__repr__()) +
                " is not known to mpi4py."))

        self.mpi_dtype = self._my_dtype_converter.to_mpi(self.dtype)

        self.slicer = slicer
        self._local_size = self.slicer(comm=comm, **remaining_parsed_kwargs)
        self.local_start = self._local_size[0]
        self.local_end = self._local_size[1]
        self.global_shape = self._local_size[2]

        self.local_length = self.local_end - self.local_start
        self.local_shape = (self.local_length,) + tuple(self.global_shape[1:])
        self.local_dim = np.product(self.local_shape)
        self.local_dim_list = np.empty(comm.size, dtype=np.int)
        comm.Allgather([np.array(self.local_dim, dtype=np.int), MPI.INT],
                       [self.local_dim_list, MPI.INT])
        self.local_dim_offset = np.sum(self.local_dim_list[0:comm.rank])

        self.local_slice = np.array([self.local_start, self.local_end,
                                     self.local_length, self.local_dim,
                                     self.local_dim_offset],
                                    dtype=np.int)
        # collect all local_slices
        self.all_local_slices = np.empty((comm.size, 5), dtype=np.int)
        comm.Allgather([np.array((self.local_slice,), dtype=np.int), MPI.INT],
                       [self.all_local_slices, MPI.INT])
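    # Descriptive note: each row of all_local_slices holds, for one rank,
    # [local_start, local_end, local_length, local_dim, local_dim_offset],
    # i.e. the same five entries that were assembled in local_slice above.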

    def initialize_data(self, global_data, local_data, alias, path, hermitian,
                        copy, **kwargs):
        if 'h5py' in gdi and alias is not None:
            local_data = self.load_data(alias=alias, path=path)
            return (local_data, hermitian)

        if self.distribution_strategy in ['equal', 'fftw']:
            if np.isscalar(global_data):
                local_data = np.empty(self.local_shape, dtype=self.dtype)
                local_data.fill(global_data)
                hermitian = True
            else:
                local_data = self.distribute_data(data=global_data,
                                                  copy=copy)
        elif self.distribution_strategy in ['freeform']:
            if isinstance(global_data, distributed_data_object):
                local_data = global_data.get_local_data(copy=copy)
            elif np.isscalar(local_data):
                temp_local_data = np.empty(self.local_shape,
                                           dtype=self.dtype)
                temp_local_data.fill(local_data)
                local_data = temp_local_data
                hermitian = True
            elif local_data is None:
                local_data = np.empty(self.local_shape, dtype=self.dtype)
            elif isinstance(local_data, np.ndarray):
                local_data = local_data.astype(
                               self.dtype, copy=copy).reshape(self.local_shape)
            else:
                local_data = np.array(local_data).astype(
                    self.dtype, copy=copy).reshape(self.local_shape)
        else:
            raise TypeError(about._errors.cstring(
                "ERROR: Unknown distribution strategy"))
        return (local_data, hermitian)

    def globalize_flat_index(self, index):
        return int(index) + self.local_dim_offset

    def globalize_index(self, index):
        index = np.array(index, dtype=np.int).flatten()
        if index.shape != (len(self.global_shape),):
            raise TypeError(about._errors.cstring("ERROR: Length\
                of index tuple does not match the array's shape!"))
        globalized_index = index
        globalized_index[0] = index[0] + self.local_start
        # ensure that the globalized index list is within the bounds
        global_index_memory = globalized_index
        globalized_index = np.clip(globalized_index,
                                   -np.array(self.global_shape),
                                   np.array(self.global_shape) - 1)
        if np.any(global_index_memory != globalized_index):
            about.warnings.cprint("WARNING: Indices were clipped!")
        globalized_index = tuple(globalized_index)
        return globalized_index

    def _allgather(self, thing, comm=None):
        if comm is None:
            comm = self.comm
        gathered_things = comm.allgather(thing)
        return gathered_things

    def _Allreduce_helper(self, sendbuf, recvbuf, op):
        send_dtype = self._my_dtype_converter.to_mpi(sendbuf.dtype)
        recv_dtype = self._my_dtype_converter.to_mpi(recvbuf.dtype)
        self.comm.Allreduce([sendbuf, send_dtype],
                            [recvbuf, recv_dtype],
                            op=op)
        return recvbuf

    def _contraction_helper(self, parent, function, axis=None, **kwargs):
        if axis == ():
            return parent.copy()

        old_shape = parent.shape
        axis = cast_axis_to_tuple(axis)
        if axis is None:
            new_shape = ()
        else:
            new_shape = tuple([old_shape[i] for i in xrange(len(old_shape))
                               if i not in axis])

        # do the contraction on the node's local data
        local_data = parent.data
        contracted_local_data = function(local_data, axis=axis, **kwargs)
        new_dtype = contracted_local_data.dtype

        # check if additional contraction along the first axis must be done
        if axis is None or 0 in axis:
            (mpi_op, bufferQ) = op_translate_dict[function]
            # check whether the buffer-based Allreduce can be used instead of
            # the generic allreduce
            use_Uppercase = False
            if bufferQ and isinstance(contracted_local_data, np.ndarray):
                # MPI.MAX and MPI.MIN do not support complex data types
                if not np.issubdtype(contracted_local_data.dtype,
                                     np.complexfloating):
                    use_Uppercase = True
            if use_Uppercase:
                global_contracted_local_data = np.empty_like(
                    contracted_local_data)
                new_mpi_dtype = self._my_dtype_converter.to_mpi(new_dtype)
                self.comm.Allreduce([contracted_local_data,
                                     new_mpi_dtype],
                                    [global_contracted_local_data,
                                     new_mpi_dtype],
                                    op=mpi_op)
            else:
                global_contracted_local_data = self.comm.allreduce(
                    contracted_local_data, op=mpi_op)
            new_dist_strategy = 'not'
        else:
            new_dist_strategy = parent.distribution_strategy
            global_contracted_local_data = contracted_local_data

        if new_shape == ():
            result = global_contracted_local_data
        else:
            # try to store the result in a distributed_data_object with the
            # same distribution_strategy as the parent
            result = parent.copy_empty(global_shape=new_shape,
                                       dtype=new_dtype,
                                       distribution_strategy=new_dist_strategy)

            # However, there are cases where the contracted data no longer
            # follows the prior distribution scheme.
            # Example: FFTW distribution on 4 MPI processes
            # Contracting (4, 4) to (4,).
            # (4, 4) was distributed (1, 4)...(1, 4)
            # (4, ) is not distributed like (1,)...(1,) but like (2,)(2,)()()!
            if result.local_shape != global_contracted_local_data.shape:
                result = parent.copy_empty(
                                    local_shape=global_contracted_local_data.shape,
                                    dtype=new_dtype,
                                    distribution_strategy='freeform')
            result.set_local_data(global_contracted_local_data, copy=False)

        return result
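    # Illustrative sketch (comment only): a global sum over axis 0 first
    # contracts the local data with the given numpy function and then combines
    # the per-rank results with the MPI operation that op_translate_dict
    # associates with it; contractions that leave axis 0 untouched stay purely
    # local.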

    def distribute_data(self, data=None, alias=None,
                        path=None, copy=True, **kwargs):
        '''
        distribute_data checks
        - whether the data is located on all nodes or only on node 0
        - that the shape of 'data' matches the global_shape
        '''

        comm = self.comm

        if 'h5py' in gdi and alias is not None:
            data = self.load_data(alias=alias, path=path)

        local_data_available_Q = (data is not None)
        data_available_Q = np.array(comm.allgather(local_data_available_Q))

        if np.all(data_available_Q == False):
            return np.empty(self.local_shape, dtype=self.dtype, order='C')
        # if all nodes got data, we assume that it is the right data and
        # store it individually.
        elif np.all(data_available_Q == True):
            if isinstance(data, distributed_data_object):
                temp_d2o = data.get_data((slice(self.local_start,
                                                self.local_end),),
                                         local_keys=True,
                                         copy=copy)
                return temp_d2o.get_local_data(copy=False).astype(self.dtype,
                                                                  copy=False)
            else:
                return data[self.local_start:self.local_end].astype(
                    self.dtype,
                    copy=copy)
        else:
            raise ValueError(
                "ERROR: distribute_data must get data on all nodes!")

    def _disperse_data_primitive(self, data, to_key, data_update, from_key,
                                 copy, to_found, to_found_boolean, from_found,
                                 from_found_boolean, **kwargs):
        if np.isscalar(data_update):
            from_key = None

        # Case 1: to_key is a slice-tuple. Hence, the basic indexing/slicing
        # machinery will be used
        if to_found == 'slicetuple':
            if from_found == 'slicetuple':
                return self.disperse_data_to_slices(data=data,
                                                    to_slices=to_key,
                                                    data_update=data_update,
                                                    from_slices=from_key,
                                                    copy=copy,
                                                    **kwargs)
            else:
                if from_key is not None:
                    about.infos.cprint(
                        "INFO: Advanced injection is not available for this " +
                        "combination of to_key and from_key.")
                    prepared_data_update = data_update[from_key]
                else:
                    prepared_data_update = data_update

                return self.disperse_data_to_slices(
                                            data=data,
                                            to_slices=to_key,
                                            data_update=prepared_data_update,
                                            copy=copy,
                                            **kwargs)

        # Case 2: key is an array
        elif (to_found == 'ndarray' or to_found == 'd2o'):
            # Case 2.1: The array is boolean.
            if to_found_boolean:
                if from_key is not None:
                    about.infos.cprint(
                        "INFO: Advanced injection is not available for this " +
                        "combination of to_key and from_key.")
                    prepared_data_update = data_update[from_key]
                else:
                    prepared_data_update = data_update
                return self.disperse_data_to_bool(
                                              data=data,
                                              to_boolean_key=to_key,
                                              data_update=prepared_data_update,
                                              copy=copy,
                                              **kwargs)
            # Case 2.2: The array is not boolean. Only 1-dimensional
            # advanced slicing is supported.
            else:
                if len(to_key.shape) != 1:
                    raise ValueError(about._errors.cstring(
                        "WARNING: Only one-dimensional advanced indexing " +
                        "is supported"))
                # Make a recursive call in order to trigger the 'list'-section
                return self.disperse_data(data=data, to_key=[to_key],
                                          data_update=data_update,
                                          from_key=from_key, copy=copy,
                                          **kwargs)

        # Case 3: to_key is a list. This list is interpreted as a
        # one-dimensional advanced indexing list.
        elif to_found == 'indexinglist':
            if from_key is not None:
                about.infos.cprint(
                    "INFO: Advanced injection is not available for this " +
                    "combination of to_key and from_key.")
                prepared_data_update = data_update[from_key]
            else:
                prepared_data_update = data_update
            return self.disperse_data_to_list(data=data,
                                              to_list_key=to_key,
                                              data_update=prepared_data_update,
                                              copy=copy,
                                              **kwargs)

    def disperse_data_to_list(self, data, to_list_key, data_update,
                              copy=True, **kwargs):

        if to_list_key == []:
            return data

        local_to_list_key = self._advanced_index_decycler(to_list_key)
        return self._disperse_data_to_list_and_bool_helper(
            data=data,
            local_to_key=local_to_list_key,
            data_update=data_update,
            copy=copy,
            **kwargs)

    def disperse_data_to_bool(self, data, to_boolean_key, data_update,
                              copy=True, **kwargs):
        # Extract the part of the to_boolean_key which corresponds to the
        # local data
        local_to_boolean_key = self.extract_local_data(to_boolean_key)
        return self._disperse_data_to_list_and_bool_helper(
            data=data,
            local_to_key=local_to_boolean_key,
            data_update=data_update,
            copy=copy,
            **kwargs)

    def _disperse_data_to_list_and_bool_helper(self, data, local_to_key,
                                               data_update, copy, **kwargs):
        comm = self.comm
        rank = comm.rank
        # Infer the length and offset of the locally affected data
        locally_affected_data = data[local_to_key]
        data_length = np.shape(locally_affected_data)[0]
        data_length_list = comm.allgather(data_length)
        data_length_offset_list = np.append([0],
                                            np.cumsum(data_length_list)[:-1])

        # Update the local data object with its very own portion
        o = data_length_offset_list
        l = data_length

        if isinstance(data_update, distributed_data_object):
            local_data_update = data_update.get_data(
                                          slice(o[rank], o[rank] + l),
                                          local_keys=True
                                          ).get_local_data(copy=False)
            data[local_to_key] = local_data_update.astype(self.dtype,
                                                          copy=False)
        elif np.isscalar(data_update):
            data[local_to_key] = data_update
        else:
            data[local_to_key] = np.array(data_update[o[rank]:o[rank] + l],
                                          copy=copy).astype(self.dtype,
                                                            copy=False)
        return data
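    # Descriptive note: o[rank] is the number of affected elements on the
    # lower ranks, so each rank copies exactly its own portion
    # data_update[o[rank]:o[rank] + l] into the locally selected entries.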

    def disperse_data_to_slices(self, data, to_slices,
                                data_update, from_slices=None, copy=True):
        comm = self.comm
        (to_slices, sliceified) = self._sliceify(to_slices)

        # parse the to_slices object
        localized_to_start, localized_to_stop = self._backshift_and_decycle(
            to_slices[0], self.local_start, self.local_end,
            self.global_shape[0])
        local_to_slice = (slice(localized_to_start, localized_to_stop,
                                to_slices[0].step),) + to_slices[1:]
        local_to_slice_shape = data[local_to_slice].shape

        to_step = to_slices[0].step
        if to_step is None:
            to_step = 1
        elif to_step == 0:
            raise ValueError(about._errors.cstring(
                "ERROR: to_step size == 0!"))

        # Compute the offset of the data the individual node will take.
        # The offset is free of stepsizes. It is the offset in terms of
        # the purely transported data. If to_step < 0, the offset will
        # be calculated in reverse order
        order = np.sign(to_step)

        local_affected_data_length = local_to_slice_shape[0]
        local_affected_data_length_list = np.empty(comm.size, dtype=np.int)
        comm.Allgather(
            [np.array(local_affected_data_length, dtype=np.int), MPI.INT],
            [local_affected_data_length_list, MPI.INT])
        local_affected_data_length_offset_list = np.append([0],
                                                           np.cumsum(
            local_affected_data_length_list[::order])[:-1])[::order]

        if np.isscalar(data_update):
            data[local_to_slice] = data_update
        else:
            # construct the locally adapted from_slice object
            r = comm.rank
            o = local_affected_data_length_offset_list
            l = local_affected_data_length

            data_update = self._enfold(data_update, sliceified)

            # parse the from_slices object
            if from_slices is None:
                from_slices = (slice(None, None, None),)
            (from_slices_start, from_slices_stop) = \
                self._backshift_and_decycle(
                                            slice_object=from_slices[0],
                                            shifted_start=0,
                                            shifted_stop=data_update.shape[0],
                                            global_length=data_update.shape[0])
            if from_slices_start is None:
                raise ValueError(about._errors.cstring(
                    "ERROR: _backshift_and_decycle should never return " +
                    "None for local_start!"))

            # parse the step sizes
            from_step = from_slices[0].step
            if from_step is None:
                from_step = 1
            elif from_step == 0:
                raise ValueError(about._errors.cstring(
                    "ERROR: from_step size == 0!"))

            localized_from_start = from_slices_start + from_step * o[r]
            localized_from_stop = localized_from_start + from_step * l
            if localized_from_stop < 0:
                localized_from_stop = None

            localized_from_slice = (slice(localized_from_start,
                                          localized_from_stop,
                                          from_step),)

            update_slice = localized_from_slice + from_slices[1:]

            if isinstance(data_update, distributed_data_object):
                selected_update = data_update.get_data(
                                 key=update_slice,
                                 local_keys=True)
                local_data_update = selected_update.get_local_data(copy=False)
                local_data_update = local_data_update.astype(self.dtype,
                                                             copy=False)
                if np.prod(np.shape(local_data_update)) != 0:
                    data[local_to_slice] = local_data_update
            # elif np.isscalar(data_update):
            #    data[local_to_slice] = data_update
            else:
                local_data_update = np.array(data_update)[update_slice]
                if np.prod(np.shape(local_data_update)) != 0:
                    data[local_to_slice] = np.array(
                                                local_data_update,
                                                copy=copy).astype(self.dtype,
                                                                  copy=False)
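    # Descriptive note: for a negative to_step the per-rank offsets are
    # accumulated in reverse rank order (via the [::order] indexing above), so
    # that the step-free portions of data_update still line up with the ranks
    # that own the corresponding target slices.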

    def collect_data(self, data, key, local_keys=False, copy=True, **kwargs):
        # collect_data supports three types of keys
        # Case 1: key is a slicing/index tuple
        # Case 2: key is a boolean-array of the same shape as self
        # Case 3: key is a list of shape (n,), where n is
        #         0<n<len(self.shape). The entries of the list must be a
        #         scalar/list/tuple/ndarray. If not scalar the length must be
        #         the same for all of the lists. This is essentially
        #         numpy advanced indexing in one dimension, only.

        # Check which case we got:
        (found, found_boolean) = _infer_key_type(key)
        comm = self.comm
        if local_keys is False:
            return self._collect_data_primitive(data, key, found,
                                                found_boolean, copy=copy,
                                                **kwargs)
        else:
            # assert that all keys are of the same type
            found_list = comm.allgather(fou