distributor_factory.py 89.3 KB
Newer Older
ultimanet's avatar
ultimanet committed
1
# -*- coding: utf-8 -*-
2

3
4
import numbers

ultimanet's avatar
ultimanet committed
5
import numpy as np
6

theos's avatar
theos committed
7
8
9
from nifty.keepers import about,\
                          global_configuration as gc,\
                          global_dependency_injector as gdi
10

theos's avatar
theos committed
11
from distributed_data_object import distributed_data_object
12

theos's avatar
theos committed
13
14
15
16
from d2o_iter import d2o_slicing_iter,\
                     d2o_not_iter
from d2o_librarian import d2o_librarian
from dtype_converter import dtype_converter
17
18
from cast_axis_to_tuple import cast_axis_to_tuple
from translate_to_mpi_operator import op_translate_dict
19

theos's avatar
theos committed
20
from strategies import STRATEGIES
21

theos's avatar
theos committed
22
23
24
# The MPI implementation is taken from the dependency injector, using the
# module name stored under 'mpi_module' in the global configuration.
MPI = gdi[gc['mpi_module']]
# Optional dependencies, fetched with ``gdi.get``.
# NOTE(review): presumably ``None`` if the package is unavailable -- confirm.
h5py = gdi.get('h5py')
pyfftw = gdi.get('pyfftw')
ultimanet's avatar
ultimanet committed
25

26

27
class _distributor_factory(object):
28

29
30
    def __init__(self):
        self.distributor_store = {}
31

32
    def parse_kwargs(self, distribution_strategy, comm,
33
34
35
                     global_data=None, global_shape=None,
                     local_data=None, local_shape=None,
                     alias=None, path=None,
36
37
38
39
40
41
42
43
44
45
46
47
                     dtype=None, skip_parsing=False, **kwargs):

        if skip_parsing:
            return_dict = {'comm': comm,
                           'dtype': dtype,
                           'name': distribution_strategy
                           }
            if distribution_strategy in STRATEGIES['global']:
                return_dict['global_shape'] = global_shape
            elif distribution_strategy in STRATEGIES['local']:
                return_dict['local_shape'] = local_shape
            return return_dict
Ultima's avatar
Ultima committed
48

49
        return_dict = {}
50
51
52
53
54

        expensive_checks = gc['d2o_init_checks']

        # Parse the MPI communicator
        if comm is None:
Ultima's avatar
Ultima committed
55
            raise ValueError(about._errors.cstring(
56
57
58
59
60
61
62
63
64
65
66
                "ERROR: The distributor needs MPI-communicator object comm!"))
        else:
            return_dict['comm'] = comm

        if expensive_checks:
            # Check that all nodes got the same distribution_strategy
            strat_list = comm.allgather(distribution_strategy)
            if all(x == strat_list[0] for x in strat_list) == False:
                raise ValueError(about._errors.cstring(
                    "ERROR: The distribution-strategy must be the same on " +
                    "all nodes!"))
67

68
        # Check for an hdf5 file and open it if given
69
        if 'h5py' in gdi and alias is not None:
70
71
72
            # set file path
            file_path = path if (path is not None) else alias
            # open hdf5 file
73
            if h5py.get_config().mpi and gc['mpi_module'] == 'MPI':
Ultima's avatar
Ultima committed
74
75
                f = h5py.File(file_path, 'r', driver='mpio', comm=comm)
            else:
76
77
78
                f = h5py.File(file_path, 'r')
            # open alias in file
            dset = f[alias]
Ultima's avatar
Ultima committed
79
80
81
        else:
            dset = None

82
        # Parse the datatype
Ultima's avatar
Ultima committed
83
        if distribution_strategy in ['not', 'equal', 'fftw'] and \
84
                (dset is not None):
85
            dtype = dset.dtype
86
87

        elif distribution_strategy in ['not', 'equal', 'fftw']:
Ultima's avatar
Ultima committed
88
            if dtype is None:
89
                if global_data is None:
Ultima's avatar
Ultima committed
90
91
                    dtype = np.dtype('float64')
                    about.infos.cprint('INFO: dtype set was set to default.')
92
                else:
Ultima's avatar
Ultima committed
93
                    try:
94
                        dtype = global_data.dtype
Ultima's avatar
Ultima committed
95
                    except(AttributeError):
96
97
98
                        dtype = np.array(global_data).dtype
            else:
                dtype = np.dtype(dtype)
99

100
        elif distribution_strategy in STRATEGIES['local']:
101
            if dtype is None:
Ultima's avatar
Ultima committed
102
103
104
                if isinstance(global_data, distributed_data_object):
                    dtype = global_data.dtype
                elif local_data is not None:
Ultima's avatar
Ultima committed
105
                    try:
106
                        dtype = local_data.dtype
Ultima's avatar
Ultima committed
107
                    except(AttributeError):
108
                        dtype = np.array(local_data).dtype
Ultima's avatar
Ultima committed
109
110
111
                else:
                    dtype = np.dtype('float64')
                    about.infos.cprint('INFO: dtype set was set to default.')
112

Ultima's avatar
Ultima committed
113
            else:
114
                dtype = np.dtype(dtype)
115
116
117
118
119
        if expensive_checks:
            dtype_list = comm.allgather(dtype)
            if all(x == dtype_list[0] for x in dtype_list) == False:
                raise ValueError(about._errors.cstring(
                    "ERROR: The given dtype must be the same on all nodes!"))
Ultima's avatar
Ultima committed
120
        return_dict['dtype'] = dtype
121
122
123

        # Parse the shape
        # Case 1: global-type slicer
124
        if distribution_strategy in STRATEGIES['global']:
Ultima's avatar
Ultima committed
125
126
127
128
129
130
131
132
133
            if dset is not None:
                global_shape = dset.shape
            elif global_data is not None and np.isscalar(global_data) == False:
                global_shape = global_data.shape
            elif global_shape is not None:
                global_shape = tuple(global_shape)
            else:
                raise ValueError(about._errors.cstring(
                    "ERROR: Neither non-0-dimensional global_data nor " +
134
                    "global_shape nor hdf5 file supplied!"))
Ultima's avatar
Ultima committed
135
136
            if global_shape == ():
                raise ValueError(about._errors.cstring(
Ultima's avatar
Ultima committed
137
                    "ERROR: global_shape == () is not a valid shape!"))
138

139
140
141
142
143
144
145
            if expensive_checks:
                global_shape_list = comm.allgather(global_shape)
                if not all(x == global_shape_list[0]
                           for x in global_shape_list):
                    raise ValueError(about._errors.cstring(
                        "ERROR: The global_shape must be the same on all " +
                        "nodes!"))
Ultima's avatar
Ultima committed
146
147
            return_dict['global_shape'] = global_shape

148
        # Case 2: local-type slicer
Ultima's avatar
Ultima committed
149
150
151
152
        elif distribution_strategy in ['freeform']:
            if isinstance(global_data, distributed_data_object):
                local_shape = global_data.local_shape
            elif local_data is not None and np.isscalar(local_data) == False:
Ultima's avatar
Ultima committed
153
154
155
156
157
158
                local_shape = local_data.shape
            elif local_shape is not None:
                local_shape = tuple(local_shape)
            else:
                raise ValueError(about._errors.cstring(
                    "ERROR: Neither non-0-dimensional local_data nor " +
159
                    "local_shape nor global d2o supplied!"))
Ultima's avatar
Ultima committed
160
161
            if local_shape == ():
                raise ValueError(about._errors.cstring(
162
163
                    "ERROR: local_shape == () is not a valid shape!"))

164
165
166
167
168
169
170
171
            if expensive_checks:
                local_shape_list = comm.allgather(local_shape[1:])
                cleared_set = set(local_shape_list)
                cleared_set.discard(())
                if len(cleared_set) > 1:
                    raise ValueError(about._errors.cstring(
                        "ERROR: All but the first entry of local_shape " +
                        "must be the same on all nodes!"))
Ultima's avatar
Ultima committed
172
173
            return_dict['local_shape'] = local_shape

174
        # Add the name of the distributor if needed
175
176
        if distribution_strategy in ['equal', 'fftw', 'freeform']:
            return_dict['name'] = distribution_strategy
177
178

        # close the file-handle
Ultima's avatar
Ultima committed
179
180
181
        if dset is not None:
            f.close()

182
        return return_dict
183

Ultima's avatar
Ultima committed
184
    def hash_arguments(self, distribution_strategy, **kwargs):
185
        kwargs = kwargs.copy()
186

187
188
        comm = kwargs['comm']
        kwargs['comm'] = id(comm)
189

190
        if 'global_shape' in kwargs:
Ultima's avatar
Ultima committed
191
            kwargs['global_shape'] = kwargs['global_shape']
192
        if 'local_shape' in kwargs:
193
194
195
            local_shape = kwargs['local_shape']
            local_shape_list = comm.allgather(local_shape)
            kwargs['local_shape'] = tuple(local_shape_list)
196

Ultima's avatar
Ultima committed
197
        kwargs['dtype'] = self.dictionize_np(kwargs['dtype'])
198
        kwargs['distribution_strategy'] = distribution_strategy
199

200
        return frozenset(kwargs.items())
ultimanet's avatar
ultimanet committed
201

202
    def dictionize_np(self, x):
203
        dic = x.type.__dict__.items()
204
        if x is np.float:
205
            dic[24] = 0
206
207
            dic[29] = 0
            dic[37] = 0
208
209
        return frozenset(dic)

210
    def get_distributor(self, distribution_strategy, comm, **kwargs):
211
        # check if the distribution strategy is known
212
        if distribution_strategy not in STRATEGIES['all']:
213
            raise ValueError(about._errors.cstring(
214
                "ERROR: Unknown distribution strategy supplied."))
215
216

        # parse the kwargs
Ultima's avatar
Ultima committed
217
        parsed_kwargs = self.parse_kwargs(
218
219
220
            distribution_strategy=distribution_strategy,
            comm=comm,
            **kwargs)
221

Ultima's avatar
Ultima committed
222
223
        hashed_kwargs = self.hash_arguments(distribution_strategy,
                                            **parsed_kwargs)
224
        # check if the distributors has already been produced in the past
225
        if hashed_kwargs in self.distributor_store:
Ultima's avatar
Ultima committed
226
            return self.distributor_store[hashed_kwargs]
227
        else:
228
            # produce new distributor
229
            if distribution_strategy == 'not':
Ultima's avatar
Ultima committed
230
                produced_distributor = _not_distributor(**parsed_kwargs)
231

232
233
            elif distribution_strategy == 'equal':
                produced_distributor = _slicing_distributor(
234
235
                    slicer=_equal_slicer,
                    **parsed_kwargs)
236

237
238
            elif distribution_strategy == 'fftw':
                produced_distributor = _slicing_distributor(
239
240
                    slicer=_fftw_slicer,
                    **parsed_kwargs)
Ultima's avatar
Ultima committed
241
242
            elif distribution_strategy == 'freeform':
                produced_distributor = _slicing_distributor(
243
244
                    slicer=_freeform_slicer,
                    **parsed_kwargs)
245
246

            self.distributor_store[hashed_kwargs] = produced_distributor
Ultima's avatar
Ultima committed
247
            return self.distributor_store[hashed_kwargs]
248
249


250
# Module-level factory instance. Its ``distributor_store`` cache is shared by
# every caller that requests distributors through this object.
distributor_factory = _distributor_factory()
Ultima's avatar
Ultima committed
251
252
253
254
255
256


def _infer_key_type(key):
    if key is None:
        return (None, None)
    found_boolean = False
257
    # Check which case we got:
258
259
260
    if isinstance(key, slice) or np.isscalar(key):
        found = 'slicetuple'
    elif isinstance(key, tuple) or isinstance(key, list):
261
262
        # Check if there is something different in the array than
        # scalars and slices
Ultima's avatar
Ultima committed
263
264
265
        scalarQ = np.array(map(np.isscalar, key))
        sliceQ = np.array(map(lambda z: isinstance(z, slice), key))
        if np.all(scalarQ + sliceQ):
266
            found = 'slicetuple'
Ultima's avatar
Ultima committed
267
268
269
270
271
272
273
274
        else:
            found = 'indexinglist'
    elif isinstance(key, np.ndarray):
        found = 'ndarray'
        found_boolean = (key.dtype.type == np.bool_)
    elif isinstance(key, distributed_data_object):
        found = 'd2o'
        found_boolean = (key.dtype == np.bool_)
275
276
    else:
        raise ValueError(about._errors.cstring("ERROR: Unknown keytype!"))
Ultima's avatar
Ultima committed
277
278
279
280
    return (found, found_boolean)


class distributor(object):
    """Base class providing ``disperse_data`` for concrete distributors.

    The methods use ``self.comm`` and ``self._disperse_data_primitive``,
    neither of which is defined in this class; they have to be supplied by
    the concrete distributor.
    """

    def disperse_data(self, data, to_key, data_update, from_key=None,
                      local_keys=False, copy=True, **kwargs):
        """Write `data_update` into `data` at the position(s) `to_key`.

        Parameters
        ----------
        data
            Target data, forwarded to ``_disperse_data_primitive``.
        to_key, from_key
            Keys selecting the target region and the part of `data_update`
            to use; see ``_infer_key_type`` for the supported kinds.
        data_update
            Scalar, array-like or distributed_data_object with new values.
        local_keys : bool
            If False, the arguments are forwarded directly and the result
            of ``_disperse_data_primitive`` is returned.  Otherwise every
            node is assumed to have passed its own keys/update; they are
            gathered and applied one node after another (returns None).
        copy : bool
            Forwarded to ``_disperse_data_primitive``.
        """
        # Check which keys we got:
        (to_found, to_found_boolean) = _infer_key_type(to_key)
        (from_found, from_found_boolean) = _infer_key_type(from_key)

        comm = self.comm

        if local_keys is False:
            return self._disperse_data_primitive(
                                         data=data,
                                         to_key=to_key,
                                         data_update=data_update,
                                         from_key=from_key,
                                         copy=copy,
                                         to_found=to_found,
                                         to_found_boolean=to_found_boolean,
                                         from_found=from_found,
                                         from_found_boolean=from_found_boolean,
                                         **kwargs)

        # assert that the keys are of the same type on every node
        for local_kind in (to_found, to_found_boolean,
                           from_found, from_found_boolean):
            kind_list = comm.allgather(local_kind)
            assert(all(x == kind_list[0] for x in kind_list))

        # gather the local keys into global key lists
        to_key_list = self._gather_local_keys(to_key, to_found)
        from_key_list = self._gather_local_keys(from_key, from_found)

        # The kind of update is the same for every loop iteration, hence
        # the collective gathers are done only once in front of the loop.
        local_scalar_list = comm.allgather(np.isscalar(data_update))
        if all(local_scalar_list):
            data_update_list = comm.allgather(data_update)
        elif isinstance(data_update, distributed_data_object):
            data_update_index_list = comm.allgather(data_update.index)
            data_update_list = [d2o_librarian[z]
                                for z in data_update_index_list]
        else:
            data_update_list = None

        # disperse the data one node after another
        for i in range(len(to_key_list)):
            if data_update_list is not None:
                temp_data_update = data_update_list[i]
            else:
                temp_data_update = self._single_node_update(data_update, i)

            self._disperse_data_primitive(
                                      data=data,
                                      to_key=to_key_list[i],
                                      data_update=temp_data_update,
                                      from_key=from_key_list[i],
                                      copy=copy,
                                      to_found=to_found,
                                      to_found_boolean=to_found_boolean,
                                      from_found=from_found,
                                      from_found_boolean=from_found_boolean,
                                      **kwargs)

    def _gather_local_keys(self, key, found):
        """Return the list of the keys of all nodes.

        Keys which are no distributed_data_objects are gathered directly.
        For d2o keys only their ``index`` is gathered and the objects are
        looked up in the ``d2o_librarian``.
        """
        if found != 'd2o':
            return self.comm.allgather(key)
        index_list = self.comm.allgather(key.index)
        return [d2o_librarian[z] for z in index_list]

    def _single_node_update(self, data_update, root):
        """Build a freeform d2o which only contains data from node `root`.

        Shape and dtype of the update on `root` are broadcast; all other
        nodes contribute an empty array (first axis of length 0).
        """
        comm = self.comm
        if comm.rank == root:
            temp_shape = np.shape(data_update)
            try:
                temp_dtype = np.dtype(data_update.dtype)
            except (AttributeError, TypeError):
                # e.g. nested lists do not possess a dtype attribute
                temp_dtype = np.array(data_update).dtype
        else:
            temp_shape = None
            temp_dtype = None
        temp_shape = comm.bcast(temp_shape, root=root)
        temp_dtype = comm.bcast(temp_dtype, root=root)

        if comm.rank != root:
            temp_shape = list(temp_shape)
            temp_shape[0] = 0
            temp_shape = tuple(temp_shape)
            temp_data = np.empty(temp_shape, dtype=temp_dtype)
        else:
            temp_data = data_update

        return distributed_data_object(
                            local_data=temp_data,
                            distribution_strategy='freeform',
                            copy=False,
                            comm=self.comm)

387

Ultima's avatar
Ultima committed
388
class _slicing_distributor(distributor):
389
    def __init__(self, slicer, name, dtype, comm, **remaining_parsed_kwargs):
        """Set up the slicing information of this node.

        Parameters
        ----------
        slicer : callable
            Called as ``slicer(comm=comm, **remaining_parsed_kwargs)``; its
            result is indexed as (local_start, local_end, global_shape).
        name : str
            Stored as ``distribution_strategy``.
        dtype
            Anything ``np.dtype`` accepts.
        comm : MPI communicator

        Raises
        ------
        TypeError
            If the dtype converter does not know `dtype`.
        """
        # `reduce` is a builtin on Python 2 only
        from functools import reduce

        self.comm = comm
        self.distribution_strategy = name
        self.dtype = np.dtype(dtype)

        self._my_dtype_converter = dtype_converter

        if not self._my_dtype_converter.known_np_Q(self.dtype):
            raise TypeError(about._errors.cstring(
                "ERROR: The datatype " + repr(self.dtype) +
                " is not known to mpi4py."))

        self.mpi_dtype = self._my_dtype_converter.to_mpi(self.dtype)

        self.slicer = slicer
        self._local_size = self.slicer(comm=comm, **remaining_parsed_kwargs)
        self.local_start = self._local_size[0]
        self.local_end = self._local_size[1]
        self.global_shape = self._local_size[2]
        self.global_dim = reduce(lambda x, y: x*y, self.global_shape)

        self.local_length = self.local_end - self.local_start
        self.local_shape = (self.local_length,) + tuple(self.global_shape[1:])
        self.local_dim = np.prod(self.local_shape)

        # `int` is what the removed alias `np.int` used to stand for.
        # NOTE(review): the buffers are declared as MPI.INT although the
        # default integer dtype is usually wider than a C int -- confirm
        # that this pairing is intended.
        self.local_dim_list = np.empty(comm.size, dtype=int)
        comm.Allgather([np.array(self.local_dim, dtype=int), MPI.INT],
                       [self.local_dim_list, MPI.INT])
        # number of elements stored on the nodes with a lower rank
        self.local_dim_offset = np.sum(self.local_dim_list[0:comm.rank])

        self.local_slice = np.array([self.local_start, self.local_end,
                                     self.local_length, self.local_dim,
                                     self.local_dim_offset],
                                    dtype=int)
        # collect all local_slices
        self.all_local_slices = np.empty((comm.size, 5), dtype=int)
        comm.Allgather([np.array((self.local_slice,), dtype=int), MPI.INT],
                       [self.all_local_slices, MPI.INT])
427
428

    def initialize_data(self, global_data, local_data, alias, path, hermitian,
                        copy, **kwargs):
        """Produce the initial local data of this node.

        If h5py is registered in the dependency injector and `alias` is
        given, the data is loaded with ``load_data``.  Otherwise 'equal' and
        'fftw' distributors use `global_data` while 'freeform' distributors
        use a d2o passed as `global_data` or else `local_data`.

        Returns
        -------
        (local_data, hermitian) : tuple
            `hermitian` is the passed value, except for scalar input where
            it is set to True.

        Raises
        ------
        TypeError
            If the distribution strategy of the distributor is unknown.
        """
        if 'h5py' in gdi and alias is not None:
            return (self.load_data(alias=alias, path=path), hermitian)

        strategy = self.distribution_strategy

        if strategy in ('equal', 'fftw'):
            if not np.isscalar(global_data):
                distributed = self.distribute_data(data=global_data,
                                                   copy=copy)
                return (distributed, hermitian)
            # a scalar fills the whole local array
            filled = np.empty(self.local_shape, dtype=self.dtype)
            filled.fill(global_data)
            return (filled, True)

        if strategy in ('freeform',):
            if isinstance(global_data, distributed_data_object):
                return (global_data.get_local_data(copy=copy), hermitian)

            if np.isscalar(local_data):
                filled = np.empty(self.local_shape, dtype=self.dtype)
                filled.fill(local_data)
                return (filled, True)

            if local_data is None:
                blank = np.empty(self.local_shape, dtype=self.dtype)
                return (blank, hermitian)

            # everything which is no array yet gets converted first
            if not isinstance(local_data, np.ndarray):
                local_data = np.array(local_data)
            converted = local_data.astype(self.dtype, copy=copy)
            return (converted.reshape(self.local_shape), hermitian)

        raise TypeError(about._errors.cstring(
            "ERROR: Unknown istribution strategy"))

464
    def globalize_flat_index(self, index):
        """Return the local flat `index` shifted by ``local_dim_offset``."""
        local_position = int(index)
        return local_position + self.local_dim_offset
466

467
468
469
    def globalize_index(self, index):
        """Convert a local index tuple into the corresponding global one.

        The first entry is shifted by ``local_start``; afterwards every
        entry is clipped to ``[-extent, extent - 1]`` of the respective
        axis of ``global_shape``.  A warning is printed if clipping changed
        anything.

        Parameters
        ----------
        index : sequence of int
            One entry per axis of ``global_shape``.

        Returns
        -------
        tuple
            The globalized index.

        Raises
        ------
        TypeError
            If the number of entries does not match ``global_shape``.
        """
        # `int` is what the removed alias `np.int` used to stand for;
        # flatten() returns a copy, so the caller's object stays untouched.
        index = np.array(index, dtype=int).flatten()
        if index.shape != (len(self.global_shape),):
            raise TypeError(about._errors.cstring(
                "ERROR: Length of index tuple does not match the " +
                "array's shape!"))

        index[0] = index[0] + self.local_start

        # ensure that the globalized index is within the bounds
        extents = np.array(self.global_shape)
        globalized_index = np.clip(index, -extents, extents - 1)
        if np.any(index != globalized_index):
            about.warnings.cprint("WARNING: Indices were clipped!")

        return tuple(globalized_index)
483

484
    def _allgather(self, thing, comm=None):
        """Gather `thing` from all nodes via ``comm.allgather``.

        ``self.comm`` is used if no communicator is passed.
        """
        communicator = self.comm if comm is None else comm
        return communicator.allgather(thing)
489

490
    def _Allreduce_helper(self, sendbuf, recvbuf, op):
        """Run ``comm.Allreduce`` on two buffers and return `recvbuf`.

        The MPI datatypes of both buffers are derived from their numpy
        dtypes with the dtype converter of the distributor.
        """
        to_mpi = self._my_dtype_converter.to_mpi
        send_spec = [sendbuf, to_mpi(sendbuf.dtype)]
        recv_spec = [recvbuf, to_mpi(recvbuf.dtype)]
        self.comm.Allreduce(send_spec, recv_spec, op=op)
        return recvbuf

498
499
500
501
502
    def _selective_allreduce(self, data, op, bufferQ=False):
        """Reduce `data` over those nodes which actually hold data.

        Parameters
        ----------
        data : None, numpy.ndarray or arbitrary object
            Local contribution.  None and empty arrays are skipped.
        op : callable
            Called as ``op(result, contribution)`` to combine two
            contributions.
        bufferQ : bool
            Whether the buffer based ``Bcast`` may be used.  It is only
            used if every contributing node holds a non-complex array.

        Returns
        -------
        The reduced data.  If all nodes hold empty arrays, the local `data`
        is returned unchanged.

        Raises
        ------
        ValueError
            If no node contributes data.
        """
        size = self.comm.size
        rank = self.comm.rank

        if size == 1:
            if data is None:
                raise ValueError("ERROR: No process with non-None data.")
            return data

        # Classify the local data:
        # 0: None, 1: empty array, 2: no array, 3: complex array,
        # 4: every other array
        if data is None:
            local_code = 0
        elif not isinstance(data, np.ndarray):
            local_code = 2
        elif data.size == 0:
            # data.size also handles 0-dimensional arrays, whose empty
            # shape tuple cannot be reduced.
            local_code = 1
        elif np.issubdtype(data.dtype, np.complexfloating):
            # MPI.MAX and MPI.MIN do not support complex data types
            local_code = 3
        else:
            local_code = 4

        # `int` is what the removed alias `np.int` used to stand for
        got_array = np.array([local_code], dtype=int)
        got_array_list = np.empty(size, dtype=int)
        self.comm.Allgather([got_array, MPI.INT],
                            [got_array_list, MPI.INT])

        if np.all(got_array_list == 1):
            return data

        # ranks of all nodes with non-None, non-empty data
        contributors = [int(i) for i in np.flatnonzero(got_array_list > 1)]
        if len(contributors) == 0:
            raise ValueError("ERROR: No process with non-None data.")
        start = contributors[0]

        # The uppercase (buffer) functions require that op supports buffers
        # and that every contribution is real array data.
        use_buffers = bufferQ and all(got_array_list[i] == 4
                                      for i in contributors)

        if use_buffers:
            # Send the dtype and shape from the start process to the
            # others.  Only the start process touches `data` here, as it
            # may be None on the receiving nodes.
            if rank == start:
                header = (data.dtype, data.shape)
            else:
                header = None
            (new_dtype, new_shape) = self.comm.bcast(header, root=start)
            mpi_dtype = self._my_dtype_converter.to_mpi(new_dtype)

            if rank == start:
                result_data = data
            else:
                result_data = np.empty(new_shape, dtype=new_dtype)
            self.comm.Bcast([result_data, mpi_dtype], root=start)

            for i in contributors[1:]:
                if rank == i:
                    temp_data = data
                else:
                    temp_data = np.empty(new_shape, dtype=new_dtype)
                self.comm.Bcast([temp_data, mpi_dtype], root=i)
                result_data = op(result_data, temp_data)
        else:
            result_data = self.comm.bcast(data, root=start)
            for i in contributors[1:]:
                temp_data = self.comm.bcast(data, root=i)
                result_data = op(result_data, temp_data)

        return result_data

    def contraction_helper(self, parent, function, allow_empty_contractions,
                           axis=None, **kwargs):
        """Apply the contraction `function` to the data of `parent`.

        Parameters
        ----------
        parent : distributed_data_object
            Object whose ``data`` is contracted.
        function : callable
            Called as ``function(local_data, axis=axis, **kwargs)``.  If
            the first axis is contracted, it must be a key of
            ``op_translate_dict``.
        allow_empty_contractions
            Not used by this implementation.
        axis : None, int or tuple of ints
            Axes to contract; None contracts all axes and ``()`` returns a
            copy of `parent`.

        Returns
        -------
        The contracted data itself if no axis remains, otherwise a
        distributed_data_object holding it.
        """
        if axis == ():
            return parent.copy()

        old_shape = parent.shape
        axis = cast_axis_to_tuple(axis)
        if axis is None:
            new_shape = ()
        else:
            new_shape = tuple(extent for (i, extent) in enumerate(old_shape)
                              if i not in axis)

        local_data = parent.data

        try:
            contracted_local_data = function(local_data, axis=axis, **kwargs)
        except ValueError:
            # The failure is tolerated for now: if the first axis gets
            # contracted, other nodes may still contribute data.
            contracted_local_data = None

        # check if additional contraction along the first axis must be done
        if axis is None or 0 in axis:
            (mpi_op, bufferQ) = op_translate_dict[function]
            contracted_global_data = self._selective_allreduce(
                                        contracted_local_data,
                                        mpi_op,
                                        bufferQ)
            new_dist_strategy = 'not'
        else:
            if contracted_local_data is None:
                # raise the exception implicitly; should the call succeed
                # this time, its result is used
                contracted_local_data = function(local_data, axis=axis,
                                                 **kwargs)
            contracted_global_data = contracted_local_data
            new_dist_strategy = parent.distribution_strategy

        if new_shape == ():
            return contracted_global_data

        new_dtype = contracted_global_data.dtype
        # try to store the result in a distributed_data_object with the
        # distribution_strategy as parent
        result = parent.copy_empty(global_shape=new_shape,
                                   dtype=new_dtype,
                                   distribution_strategy=new_dist_strategy)

        # However, there are cases where the contracted data does not any
        # longer follow the prior distribution scheme.
        # Example: FFTW distribution on 4 MPI processes
        # Contracting (4, 4) to (4,).
        # (4, 4) was distributed (1, 4)...(1, 4)
        # (4, ) is not distributed like (1,)...(1,) but like (2,)(2,)()()!
        if result.local_shape != contracted_global_data.shape:
            result = parent.copy_empty(
                                local_shape=contracted_global_data.shape,
                                dtype=new_dtype,
                                distribution_strategy='freeform')
        result.set_local_data(contracted_global_data, copy=False)

        return result

627
    def distribute_data(self, data=None, alias=None,
                        path=None, copy=True, **kwargs):
        """Build this node's local array from `data`.

        Every node reports via `comm.allgather` whether it holds data; the
        outcome depends on the collective answer.

        Parameters
        ----------
        data : distributed_data_object, scalar, array-like or None
            The input to be distributed. Array-like input must support
            slicing along its first axis and provide `astype`.
        alias, path
            If 'h5py' is registered in `gdi` and `alias` is not None, `data`
            is replaced by the result of `self.load_data(alias=alias,
            path=path)`.
        copy : bool
            Passed on to `get_data` (d2o input) or `astype` (array input).
        **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        numpy.ndarray
            * no node has data: an uninitialised C-ordered array of shape
              `self.local_shape` and dtype `self.dtype`;
            * d2o input: the local data of the [local_start, local_end)
              window, cast to `self.dtype`;
            * scalar input: `ones(local_shape, dtype) * data`.
              NOTE(review): the product is not cast back, so its dtype
              follows NumPy's promotion rules rather than `self.dtype`;
            * otherwise: `data[local_start:local_end]` cast to `self.dtype`.

        Raises
        ------
        ValueError
            If only a subset of the nodes received data.
        """
        communicator = self.comm

        if 'h5py' in gdi and alias is not None:
            data = self.load_data(alias=alias, path=path)

        # One flag per node: does that node hold any data at all?
        has_data_per_node = np.array(communicator.allgather(data is not None))

        # Nobody supplied data -> hand back an uninitialised local buffer.
        if not has_data_per_node.any():
            return np.empty(self.local_shape, dtype=self.dtype, order='C')

        # A mixed situation cannot be resolved consistently.
        if not has_data_per_node.all():
            raise ValueError(
                "ERROR: distribute_data must get data on all nodes!")

        # From here on all nodes hold data; it is assumed to be the right
        # data and every node stores its own portion.
        if isinstance(data, distributed_data_object):
            local_window = (slice(self.local_start, self.local_end),)
            windowed_d2o = data.get_data(local_window,
                                         local_keys=True,
                                         copy=copy)
            local_part = windowed_d2o.get_local_data(copy=False)
            return local_part.astype(self.dtype, copy=False)

        if np.isscalar(data):
            return np.ones(self.local_shape, dtype=self.dtype)*data

        local_part = data[self.local_start:self.local_end]
        return local_part.astype(self.dtype, copy=copy)
664
665

    def _disperse_data_primitive(self, data, to_key, data_update, from_key,
                                 copy, to_found, to_found_boolean, from_found,
                                 from_found_boolean, **kwargs):
        """Dispatch an update of `data` according to the kind of `to_key`.

        Parameters
        ----------
        data
            The local data that receives the update.
        to_key, from_key
            Keys selecting the target region in `data` and the source region
            in `data_update`. `from_key` is ignored for scalar updates.
        data_update
            The values to be written.
        copy
            Forwarded to the specialised disperse method.
        to_found : str
            Classification of `to_key`; handled values are 'slicetuple',
            'ndarray', 'd2o' and 'indexinglist'.
        to_found_boolean : bool
            Whether an array-like `to_key` is a boolean mask.
        from_found : str
            Classification of `from_key`; only 'slicetuple' is treated
            specially (and only together with a slice-tuple `to_key`).
        from_found_boolean
            Not used in this method.
        **kwargs
            Forwarded to the specialised disperse method.

        Returns
        -------
        Whatever the selected disperse method returns. If `to_found` matches
        none of the handled values the method falls through and returns None.

        Raises
        ------
        ValueError
            If `to_key` is a non-boolean array that is not one-dimensional.
        """
        # A scalar is broadcast as a whole; there is nothing to select from.
        if np.isscalar(data_update):
            from_key = None

        def select_from_update():
            # Shared fallback for every combination that cannot hand
            # from_key on: announce it, then apply from_key up front.
            if from_key is None:
                return data_update
            about.infos.cprint(
                "INFO: Advanced injection is not available for this "
                "combination of to_key and from_key.")
            return data_update[from_key]

        # Case 1: to_key is a slice-tuple -> basic indexing/slicing machinery.
        if to_found == 'slicetuple':
            if from_found == 'slicetuple':
                # Only here both keys can be passed through untouched.
                return self.disperse_data_to_slices(data=data,
                                                    to_slices=to_key,
                                                    data_update=data_update,
                                                    from_slices=from_key,
                                                    copy=copy,
                                                    **kwargs)
            update = select_from_update()
            return self.disperse_data_to_slices(data=data,
                                                to_slices=to_key,
                                                data_update=update,
                                                copy=copy,
                                                **kwargs)

        # Case 2: to_key is an array (numpy or d2o).
        if to_found in ('ndarray', 'd2o'):
            # Case 2.1: boolean mask.
            if to_found_boolean:
                update = select_from_update()
                return self.disperse_data_to_bool(data=data,
                                                  to_boolean_key=to_key,
                                                  data_update=update,
                                                  copy=copy,
                                                  **kwargs)

            # Case 2.2: index array; only one-dimensional advanced slicing
            # is supported.
            if len(to_key.shape) != 1:
                raise ValueError(about._errors.cstring(
                    "WARNING: Only one-dimensional advanced indexing "
                    "is supported"))
            # Wrap the array in a list and re-enter through disperse_data so
            # that the 'indexinglist' section below gets triggered.
            return self.disperse_data(data=data, to_key=[to_key],
                                      data_update=data_update,
                                      from_key=from_key, copy=copy,
                                      **kwargs)

        # Case 3: to_key is a list, interpreted as a one-dimensional
        # advanced indexing list.
        if to_found == 'indexinglist':
            update = select_from_update()
            return self.disperse_data_to_list(data=data,
                                              to_list_key=to_key,
                                              data_update=update,
                                              copy=copy,
                                              **kwargs)
742
743

    def disperse_data_to_list(self, data, to_list_key, data_update,
                              copy=True, **kwargs):
        """Write `data_update` into `data` at the positions in `to_list_key`.

        Parameters
        ----------
        data
            The local data that receives the update.
        to_list_key
            Advanced-indexing list addressing the target entries. It is
            translated by `self._advanced_index_decycler` before use.
        data_update
            The values to be written.
        copy : bool
            Forwarded to `_disperse_data_to_list_and_bool_helper`.
        **kwargs
            Forwarded to `_disperse_data_to_list_and_bool_helper`.

        Returns
        -------
        `data` itself, unchanged, if `to_list_key` is an empty list;
        otherwise the return value of the helper.
        """
        # Short-circuit only for a genuinely empty list. Checking the type
        # first avoids evaluating `ndarray == []`, which NumPy treats as an
        # elementwise comparison (deprecated, and an error in newer
        # releases) instead of a plain emptiness test.
        if isinstance(to_list_key, list) and not to_list_key:
            return data

        local_to_list_key = self._advanced_index_decycler(to_list_key)
        return self._disperse_data_to_list_and_bool_helper(
            data=data,
            local_to_key=local_to_list_key,
            data_update=data_update,
            copy=copy,
            **kwargs)
756
757

    def disperse_data_to_bool(self, data, to_boolean_key, data_update,
                              copy=True, **kwargs):
        """Write `data_update` into the entries of `data` selected by a mask.

        Parameters
        ----------
        data
            The local data that receives the update.
        to_boolean_key
            Boolean key; `self.extract_local_data` reduces it to the part
            that corresponds to the local data.
        data_update
            The values to be written.
        copy : bool
            Forwarded to `_disperse_data_to_list_and_bool_helper`.
        **kwargs
            Forwarded to `_disperse_data_to_list_and_bool_helper`.

        Returns
        -------
        The return value of `_disperse_data_to_list_and_bool_helper`.
        """
        # Only the piece of the mask belonging to this node's data matters.
        local_mask = self.extract_local_data(to_boolean_key)

        return self._disperse_data_to_list_and_bool_helper(
            data=data, local_to_key=local_mask,
            data_update=data_update, copy=copy, **kwargs)
768

769
    def _disperse_data_to_list_and_bool_helper(self, data, local_to_key,
                                               data_update, copy, **kwargs):
        """Inject this node's share of `data_update` into `data`.

        The number of entries selected by `local_to_key` is gathered from all
        nodes. The counts of the lower-ranked nodes add up to the offset at
        which this node's share starts inside a non-scalar `data_update`.

        Parameters
        ----------
        data
            Local data; modified in place via `data[local_to_key] = ...`.
        local_to_key
            Key (already localized) selecting the affected entries.
        data_update : distributed_data_object, scalar or sliceable
            The values to be written. Scalars are assigned directly.
        copy : bool
            Passed to `np.array` when `data_update` is neither a d2o nor a
            scalar.
        **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        The updated `data` object.
        """
        communicator = self.comm
        my_rank = communicator.rank

        # How many entries does the key hit locally, and on every other node?
        n_local = np.shape(data[local_to_key])[0]
        counts_per_node = communicator.allgather(n_local)

        # Exclusive prefix sum: entry i is the total count of nodes 0..i-1.
        offsets = np.append([0], np.cumsum(counts_per_node)[:-1])

        # The window of data_update this node is responsible for.
        window = slice(offsets[my_rank], offsets[my_rank] + n_local)

        if isinstance(data_update, distributed_data_object):
            windowed_d2o = data_update.get_data(window, local_keys=True)
            local_update = windowed_d2o.get_local_data(copy=False)
            data[local_to_key] = local_update.astype(self.dtype, copy=False)
        elif np.isscalar(data_update):
            data[local_to_key] = data_update
        else:
            local_update = np.array(data_update[window], copy=copy)
            data[local_to_key] = local_update.astype(self.dtype, copy=False)

        return data

    def disperse_data_to_slices(self, data, to_slices,
800
                                data_update, from_slices=None, copy=True):
801
802
803
804
        comm = self.comm
        (to_slices, sliceified) = self._sliceify(to_slices)

        # parse the to_slices object
805
806
807
808
809
        localized_to_start, localized_to_stop = self._backshift_and_decycle(
            to_slices[0], self.local_start, self.local_end,
            self.global_shape[0])
        local_to_slice = (slice(localized_to_start, localized_to_stop,
                                to_slices[0].step),) + to_slices[1:]
810
        local_to_slice_shape = data[local_to_slice].shape
811

Ultima's avatar
Ultima committed
812
813
814
        to_step = to_slices[0].step
        if to_step is None:
            to_step = 1
815
        elif to_step == 0:
816
            raise ValueError(about._errors.cstring(
Ultima's avatar
Ultima committed
817
818
                "ERROR: to_step size == 0!"))

819
820
821
822
        # Compute the offset of the data the individual node will take.
        # The offset is free of stepsizes. It is the offset in terms of
        # the purely transported data. If to_step < 0, the offset will
        # be calculated in reverse order
Ultima's avatar
Ultima committed
823
        order = np.sign(to_step)
824

Ultima's avatar
Ultima committed
825
        local_affected_data_length = local_to_slice_shape[0]
826
827
828
        local_affected_data_length_list = np.empty(comm.size, dtype=np.int)
        comm.Allgather(
            [np.array(local_affected_data_length, dtype=np.int), MPI.INT],
829
            [local_affected_data_length_list, MPI.INT])
830
831
832
        local_affected_data_length_offset_list = np.append([0],
                                                           np.cumsum(
            local_affected_data_length_list[::order])[:-1])[::order]
Ultima's avatar
Ultima committed
833

834
        if np.isscalar(data_update):
Ultima's avatar
Ultima committed
835
836
            data[local_to_slice] = data_update
        else:
837
            # construct the locally adapted from_slice object
Ultima's avatar
Ultima committed
838
839
840
841
            r = comm.rank
            o = local_affected_data_length_offset_list
            l = local_affected_data_length

842
843
844
            data_update = self._enfold(data_update, sliceified)

            # parse the from_slices object
Ultima's avatar
Ultima committed
845
            if from_slices is None:
846
                from_slices = (slice(None, None, None),)
847
848
849
850
851
852
            (from_slices_start, from_slices_stop) = \
                self._backshift_and_decycle(
                                            slice_object=from_slices[0],
                                            shifted_start=0,
                                            shifted_stop=data_update.shape[0],
                                            global_length=data_update.shape[0])
Ultima's avatar
Ultima committed
853
            if from_slices_start is None:
854
855
856
                raise ValueError(about._errors.cstring(
                    "ERROR: _backshift_and_decycle should never return " +
                    "None for local_start!"))
857
858

            # parse the step sizes
859
            from_step = from_slices[0].step
Ultima's avatar
Ultima committed
860
            if from_step is None:
861
                from_step = 1
862
            elif from_step == 0:
863
                raise ValueError(about._errors.cstring(
864
                    "ERROR: from_step size == 0!"))
865

866
            localized_from_start = from_slices_start + from_step * o[r]