# -*- coding: utf-8 -*-

import numbers

import numpy as np

from nifty.keepers import about,\
                          global_configuration as gc,\
                          global_dependency_injector as gdi

from distributed_data_object import distributed_data_object

from d2o_iter import d2o_slicing_iter,\
                     d2o_not_iter
from d2o_librarian import d2o_librarian
from dtype_converter import dtype_converter

from cast_axis_to_tuple import cast_axis_to_tuple
from translate_to_mpi_operator import op_translate_dict

from strategies import STRATEGIES

MPI = gdi[gc['mpi_module']]
h5py = gdi.get('h5py')
pyfftw = gdi.get('pyfftw')


class _distributor_factory(object):

    def __init__(self):
        self.distributor_store = {}

    def parse_kwargs(self, distribution_strategy, comm,
                     global_data=None, global_shape=None,
                     local_data=None, local_shape=None,
                     alias=None, path=None,
                     dtype=None, skip_parsing=False, **kwargs):

        if skip_parsing:
            return_dict = {'comm': comm,
                           'dtype': dtype,
                           'name': distribution_strategy
                           }
            if distribution_strategy in STRATEGIES['global']:
                return_dict['global_shape'] = global_shape
            elif distribution_strategy in STRATEGIES['local']:
                return_dict['local_shape'] = local_shape
            return return_dict

        return_dict = {}

        expensive_checks = gc['d2o_init_checks']

        # Parse the MPI communicator
        if comm is None:
            raise ValueError(about._errors.cstring(
                "ERROR: The distributor needs MPI-communicator object comm!"))
        else:
            return_dict['comm'] = comm

        if expensive_checks:
            # Check that all nodes got the same distribution_strategy
            strat_list = comm.allgather(distribution_strategy)
            if all(x == strat_list[0] for x in strat_list) == False:
                raise ValueError(about._errors.cstring(
                    "ERROR: The distribution-strategy must be the same on " +
                    "all nodes!"))

        # Check for an hdf5 file and open it if given
        if 'h5py' in gdi and alias is not None:
            # set file path
            file_path = path if (path is not None) else alias
            # open hdf5 file
            if h5py.get_config().mpi and gc['mpi_module'] == 'MPI':
                f = h5py.File(file_path, 'r', driver='mpio', comm=comm)
            else:
                f = h5py.File(file_path, 'r')
            # open alias in file
            dset = f[alias]
        else:
            dset = None

        # Parse the datatype
        if distribution_strategy in ['not', 'equal', 'fftw'] and \
                (dset is not None):
            dtype = dset.dtype

        elif distribution_strategy in ['not', 'equal', 'fftw']:
            if dtype is None:
                if global_data is None:
                    dtype = np.dtype('float64')
                    about.infos.cprint('INFO: dtype was set to default.')
                else:
                    try:
                        dtype = global_data.dtype
                    except(AttributeError):
                        dtype = np.array(global_data).dtype
            else:
                dtype = np.dtype(dtype)

        elif distribution_strategy in STRATEGIES['local']:
            if dtype is None:
                if isinstance(global_data, distributed_data_object):
                    dtype = global_data.dtype
                elif local_data is not None:
                    try:
                        dtype = local_data.dtype
                    except(AttributeError):
                        dtype = np.array(local_data).dtype
                else:
                    dtype = np.dtype('float64')
                    about.infos.cprint('INFO: dtype was set to default.')

            else:
                dtype = np.dtype(dtype)

        if expensive_checks:
            dtype_list = comm.allgather(dtype)
            if all(x == dtype_list[0] for x in dtype_list) == False:
                raise ValueError(about._errors.cstring(
                    "ERROR: The given dtype must be the same on all nodes!"))
        return_dict['dtype'] = dtype

        # Parse the shape
        # Case 1: global-type slicer
        if distribution_strategy in STRATEGIES['global']:
            if dset is not None:
                global_shape = dset.shape
            elif global_data is not None and np.isscalar(global_data) == False:
                global_shape = global_data.shape
            elif global_shape is not None:
                global_shape = tuple(global_shape)
            else:
                raise ValueError(about._errors.cstring(
                    "ERROR: Neither non-0-dimensional global_data nor " +
                    "global_shape nor hdf5 file supplied!"))

            if global_shape == ():
                raise ValueError(about._errors.cstring(
                    "ERROR: global_shape == () is not a valid shape!"))

            if expensive_checks:
                global_shape_list = comm.allgather(global_shape)
                if not all(x == global_shape_list[0]
                           for x in global_shape_list):
                    raise ValueError(about._errors.cstring(
                        "ERROR: The global_shape must be the same on all " +
                        "nodes!"))

            return_dict['global_shape'] = global_shape

        # Case 2: local-type slicer
        elif distribution_strategy in ['freeform']:
            if isinstance(global_data, distributed_data_object):
                local_shape = global_data.local_shape
            elif local_data is not None and np.isscalar(local_data) == False:
                local_shape = local_data.shape
            elif local_shape is not None:
                local_shape = tuple(local_shape)
            else:
                raise ValueError(about._errors.cstring(
                    "ERROR: Neither non-0-dimensional local_data nor " +
                    "local_shape nor global d2o supplied!"))

            if local_shape == ():
                raise ValueError(about._errors.cstring(
                    "ERROR: local_shape == () is not a valid shape!"))

            if expensive_checks:
                local_shape_list = comm.allgather(local_shape[1:])
                cleared_set = set(local_shape_list)
                cleared_set.discard(())
                if len(cleared_set) > 1:
                    raise ValueError(about._errors.cstring(
                        "ERROR: All but the first entry of local_shape " +
                        "must be the same on all nodes!"))

            return_dict['local_shape'] = local_shape

        # Add the name of the distributor if needed
        if distribution_strategy in ['equal', 'fftw', 'freeform']:
            return_dict['name'] = distribution_strategy

        # close the file-handle
        if dset is not None:
            f.close()

        return return_dict

    def hash_arguments(self, distribution_strategy, **kwargs):
        kwargs = kwargs.copy()

        comm = kwargs['comm']
        kwargs['comm'] = id(comm)

        if 'global_shape' in kwargs:
            kwargs['global_shape'] = kwargs['global_shape']
        if 'local_shape' in kwargs:
            local_shape = kwargs['local_shape']
            local_shape_list = comm.allgather(local_shape)
            kwargs['local_shape'] = tuple(local_shape_list)

        kwargs['dtype'] = self.dictionize_np(kwargs['dtype'])
        kwargs['distribution_strategy'] = distribution_strategy

        return frozenset(kwargs.items())

    def dictionize_np(self, x):
        dic = x.type.__dict__.items()
        if x is np.float:
            dic[24] = 0
            dic[29] = 0
            dic[37] = 0
        return frozenset(dic)

    def get_distributor(self, distribution_strategy, comm, **kwargs):
        # check if the distribution strategy is known
        if distribution_strategy not in STRATEGIES['all']:
            raise ValueError(about._errors.cstring(
                "ERROR: Unknown distribution strategy supplied."))

        # parse the kwargs
        parsed_kwargs = self.parse_kwargs(
            distribution_strategy=distribution_strategy,
            comm=comm,
            **kwargs)

        hashed_kwargs = self.hash_arguments(distribution_strategy,
                                            **parsed_kwargs)
        # check if the distributor has already been produced in the past
        if hashed_kwargs in self.distributor_store:
            return self.distributor_store[hashed_kwargs]
        else:
            # produce a new distributor
            if distribution_strategy == 'not':
                produced_distributor = _not_distributor(**parsed_kwargs)

            elif distribution_strategy == 'equal':
                produced_distributor = _slicing_distributor(
                    slicer=_equal_slicer,
                    **parsed_kwargs)

            elif distribution_strategy == 'fftw':
                produced_distributor = _slicing_distributor(
                    slicer=_fftw_slicer,
                    **parsed_kwargs)

            elif distribution_strategy == 'freeform':
                produced_distributor = _slicing_distributor(
                    slicer=_freeform_slicer,
                    **parsed_kwargs)

            self.distributor_store[hashed_kwargs] = produced_distributor
            return self.distributor_store[hashed_kwargs]


distributor_factory = _distributor_factory()
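
# The snippet below is only an illustrative usage sketch (it assumes an
# mpi4py communicator such as MPI.COMM_WORLD) and is kept commented out so
# that importing this module has no side effects:
#
#     from mpi4py import MPI
#     import numpy as np
#
#     dist = distributor_factory.get_distributor(
#                                        distribution_strategy='equal',
#                                        comm=MPI.COMM_WORLD,
#                                        global_shape=(4, 4),
#                                        dtype=np.dtype('float64'))
#
# Repeated requests with the same arguments are served from
# distributor_store via hash_arguments instead of being rebuilt.
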
def _infer_key_type(key):
    if key is None:
        return (None, None)
    found_boolean = False
    # Check which case we got:
    if isinstance(key, tuple) or isinstance(key, slice) or np.isscalar(key):
        # Check if there is something different in the array than
        # scalars and slices
        if isinstance(key, slice) or np.isscalar(key):
            key = [key]

        scalarQ = np.array(map(np.isscalar, key))
        sliceQ = np.array(map(lambda z: isinstance(z, slice), key))
        if np.all(scalarQ + sliceQ):
            found = 'slicetuple'
        else:
            found = 'indexinglist'
    elif isinstance(key, np.ndarray):
        found = 'ndarray'
        found_boolean = (key.dtype.type == np.bool_)
    elif isinstance(key, distributed_data_object):
        found = 'd2o'
        found_boolean = (key.dtype == np.bool_)
    elif isinstance(key, list):
        found = 'indexinglist'
    else:
        raise ValueError(about._errors.cstring("ERROR: Unknown keytype!"))

    return (found, found_boolean)
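
# Illustrative examples of the classification performed above (these follow
# directly from the branches of _infer_key_type):
#
#     _infer_key_type((slice(1, 3), 2))         # -> ('slicetuple', False)
#     _infer_key_type(np.array([True, False]))  # -> ('ndarray', True)
#     _infer_key_type([0, 2, 3])                # -> ('indexinglist', False)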


class distributor(object):

    def disperse_data(self, data, to_key, data_update, from_key=None,
                      local_keys=False, copy=True, **kwargs):
        # Check which keys we got:
        (to_found, to_found_boolean) = _infer_key_type(to_key)
        (from_found, from_found_boolean) = _infer_key_type(from_key)

        comm = self.comm

        if local_keys is False:
            return self._disperse_data_primitive(
                                         data=data,
                                         to_key=to_key,
                                         data_update=data_update,
                                         from_key=from_key,
                                         copy=copy,
                                         to_found=to_found,
                                         to_found_boolean=to_found_boolean,
                                         from_found=from_found,
                                         from_found_boolean=from_found_boolean,
                                         **kwargs)

        else:
            # assert that all to_keys are of the same type
            to_found_list = comm.allgather(to_found)
            assert(all(x == to_found_list[0] for x in to_found_list))
            to_found_boolean_list = comm.allgather(to_found_boolean)
            assert(all(x == to_found_boolean_list[0] for x in
                       to_found_boolean_list))
            from_found_list = comm.allgather(from_found)
            assert(all(x == from_found_list[0] for x in from_found_list))
            from_found_boolean_list = comm.allgather(from_found_boolean)
            assert(all(x == from_found_boolean_list[0] for
                       x in from_found_boolean_list))

            # gather the local to_keys into a global to_key_list
            # Case 1: the to_keys are not distributed_data_objects
            # -> allgather does the job
            if to_found != 'd2o':
                to_key_list = comm.allgather(to_key)
            # Case 2: if the to_keys are distributed_data_objects, gather
            # the index of the array and build the to_key_list with help
            # from the librarian
            else:
                to_index_list = comm.allgather(to_key.index)
                to_key_list = map(lambda z: d2o_librarian[z], to_index_list)

            # gather the local from_keys. It is the same procedure as above
            if from_found != 'd2o':
                from_key_list = comm.allgather(from_key)
            else:
                from_index_list = comm.allgather(from_key.index)
                from_key_list = map(lambda z: d2o_librarian[z],
                                    from_index_list)

            local_data_update_is_scalar = np.isscalar(data_update)
            local_scalar_list = comm.allgather(local_data_update_is_scalar)
            for i in xrange(len(to_key_list)):
                if np.all(np.array(local_scalar_list) == True):
                    scalar_list = comm.allgather(data_update)
                    temp_data_update = scalar_list[i]
                elif isinstance(data_update, distributed_data_object):
                    data_update_index_list = comm.allgather(data_update.index)
                    data_update_list = map(lambda z: d2o_librarian[z],
                                           data_update_index_list)
                    temp_data_update = data_update_list[i]
                else:
                    # build a temporary freeform d2o which only contains data
                    # from node i
                    if comm.rank == i:
                        temp_shape = np.shape(data_update)
                        try:
                            temp_dtype = np.dtype(data_update.dtype)
                        except(TypeError):
                            temp_dtype = np.array(data_update).dtype
                    else:
                        temp_shape = None
                        temp_dtype = None
                    temp_shape = comm.bcast(temp_shape, root=i)
                    temp_dtype = comm.bcast(temp_dtype, root=i)

                    if comm.rank != i:
                        temp_shape = list(temp_shape)
                        temp_shape[0] = 0
                        temp_shape = tuple(temp_shape)
                        temp_data = np.empty(temp_shape, dtype=temp_dtype)
                    else:
                        temp_data = data_update
                    temp_data_update = distributed_data_object(
                                        local_data=temp_data,
                                        distribution_strategy='freeform',
                                        copy=False,
                                        comm=self.comm)
                # disperse the data one after another
                self._disperse_data_primitive(
                                      data=data,
                                      to_key=to_key_list[i],
                                      data_update=temp_data_update,
                                      from_key=from_key_list[i],
                                      copy=copy,
                                      to_found=to_found,
                                      to_found_boolean=to_found_boolean,
                                      from_found=from_found,
                                      from_found_boolean=from_found_boolean,
                                      **kwargs)
                i += 1


class _slicing_distributor(distributor):
    def __init__(self, slicer, name, dtype, comm, **remaining_parsed_kwargs):

        self.comm = comm
        self.distribution_strategy = name
        self.dtype = np.dtype(dtype)

        self._my_dtype_converter = dtype_converter

        if not self._my_dtype_converter.known_np_Q(self.dtype):
            raise TypeError(about._errors.cstring(
                "ERROR: The datatype " + str(self.dtype.__repr__()) +
                " is not known to mpi4py."))

        self.mpi_dtype = self._my_dtype_converter.to_mpi(self.dtype)

        self.slicer = slicer
        self._local_size = self.slicer(comm=comm, **remaining_parsed_kwargs)
        self.local_start = self._local_size[0]
        self.local_end = self._local_size[1]
        self.global_shape = self._local_size[2]
        self.global_dim = reduce(lambda x, y: x*y, self.global_shape)

        self.local_length = self.local_end - self.local_start
        self.local_shape = (self.local_length,) + tuple(self.global_shape[1:])
        self.local_dim = np.product(self.local_shape)
        self.local_dim_list = np.empty(comm.size, dtype=np.int)
        comm.Allgather([np.array(self.local_dim, dtype=np.int), MPI.INT],
                       [self.local_dim_list, MPI.INT])
        self.local_dim_offset = np.sum(self.local_dim_list[0:comm.rank])

        self.local_slice = np.array([self.local_start, self.local_end,
                                     self.local_length, self.local_dim,
                                     self.local_dim_offset],
                                    dtype=np.int)
        # collect all local_slices
        self.all_local_slices = np.empty((comm.size, 5), dtype=np.int)
        comm.Allgather([np.array((self.local_slice,), dtype=np.int), MPI.INT],
                       [self.all_local_slices, MPI.INT])

    def initialize_data(self, global_data, local_data, alias, path, hermitian,
                        copy, **kwargs):
        if 'h5py' in gdi and alias is not None:
            local_data = self.load_data(alias=alias, path=path)
            return (local_data, hermitian)

        if self.distribution_strategy in ['equal', 'fftw']:
            if np.isscalar(global_data):
                local_data = np.empty(self.local_shape, dtype=self.dtype)
                local_data.fill(global_data)
                hermitian = True
            else:
                local_data = self.distribute_data(data=global_data,
                                                  copy=copy)
        elif self.distribution_strategy in ['freeform']:
            if isinstance(global_data, distributed_data_object):
                local_data = global_data.get_local_data(copy=copy)
            elif np.isscalar(local_data):
                temp_local_data = np.empty(self.local_shape,
                                           dtype=self.dtype)
                temp_local_data.fill(local_data)
                local_data = temp_local_data
                hermitian = True
            elif local_data is None:
                local_data = np.empty(self.local_shape, dtype=self.dtype)
            elif isinstance(local_data, np.ndarray):
                local_data = local_data.astype(
                               self.dtype, copy=copy).reshape(self.local_shape)
            else:
                local_data = np.array(local_data).astype(
                    self.dtype, copy=copy).reshape(self.local_shape)
        else:
            raise TypeError(about._errors.cstring(
                "ERROR: Unknown distribution strategy"))

        return (local_data, hermitian)

    def globalize_flat_index(self, index):
        return int(index) + self.local_dim_offset

    def globalize_index(self, index):
        index = np.array(index, dtype=np.int).flatten()
        if index.shape != (len(self.global_shape),):
            raise TypeError(about._errors.cstring("ERROR: Length\
                of index tuple does not match the array's shape!"))
        globalized_index = index
        globalized_index[0] = index[0] + self.local_start
        # ensure that the globalized index list is within the bounds
        global_index_memory = globalized_index
        globalized_index = np.clip(globalized_index,
                                   -np.array(self.global_shape),
                                   np.array(self.global_shape) - 1)
        if np.any(global_index_memory != globalized_index):
            about.warnings.cprint("WARNING: Indices were clipped!")
        globalized_index = tuple(globalized_index)
        return globalized_index

    def _allgather(self, thing, comm=None):
        if comm is None:
            comm = self.comm
        gathered_things = comm.allgather(thing)
        return gathered_things

    def _Allreduce_helper(self, sendbuf, recvbuf, op):
        send_dtype = self._my_dtype_converter.to_mpi(sendbuf.dtype)
        recv_dtype = self._my_dtype_converter.to_mpi(recvbuf.dtype)
        self.comm.Allreduce([sendbuf, send_dtype],
                            [recvbuf, recv_dtype],
                            op=op)
        return recvbuf

    def _selective_allreduce(self, data, op, bufferQ=False):
        size = self.comm.size
        rank = self.comm.rank

        if size == 1:
            result_data = data
        else:

            # infer which data should be included in the allreduce and
            # whether it is array data
            if data is None:
                got_array = np.array([0])
            elif not isinstance(data, np.ndarray):
                got_array = np.array([1])
            elif np.issubdtype(data.dtype, np.complexfloating):
                # MPI.MAX and MPI.MIN do not support complex data types
                got_array = np.array([2])
            else:
                got_array = np.array([3])

            got_array_list = np.empty(size, dtype=np.int)
            self.comm.Allgather([got_array, MPI.INT],
                                [got_array_list, MPI.INT])

            # get first node with non-None data
            try:
                start = next(i for i in xrange(size) if got_array_list[i] > 0)
            except(StopIteration):
                raise ValueError("ERROR: No process with non-None data.")

            # check if the Uppercase function can be used or not
            # -> check if op supports buffers and if we got real array-data
            if bufferQ and got_array_list[start] == 3:
                # Send the dtype and shape from the start process to the others
                (new_dtype,
                 new_shape) = self.comm.bcast((data.dtype,
                                               data.shape), root=start)
                mpi_dtype = self._my_dtype_converter.to_mpi(new_dtype)
                if rank == start:
                    result_data = data
                else:
                    result_data = np.empty(new_shape, dtype=new_dtype)

                self.comm.Bcast([result_data, mpi_dtype], root=start)

                for i in xrange(start+1, size):
                    if got_array_list[i]:
                        if rank == i:
                            temp_data = data
                        else:
                            temp_data = np.empty(new_shape, dtype=new_dtype)
                        self.comm.Bcast([temp_data, mpi_dtype], root=i)
                        result_data = op(result_data, temp_data)

            else:
                result_data = self.comm.bcast(data, root=start)
                for i in xrange(start+1, size):
                    if got_array_list[i]:
                        temp_data = self.comm.bcast(data, root=i)
                        result_data = op(result_data, temp_data)
        return result_data

    def contraction_helper(self, parent, function, allow_empty_contractions,
                           axis=None, **kwargs):
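        # Note on the reduction step (illustrative, assumed ranks): if, say,
        # only ranks 1 and 3 end up with non-empty local data, the
        # _selective_allreduce call further below combines just those
        # contributions with the translated MPI operation and returns the
        # same result on every rank; None contributions are skipped.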
        if axis == ():
            return parent.copy()

        old_shape = parent.shape
        axis = cast_axis_to_tuple(axis)
        if axis is None:
            new_shape = ()
        else:
            new_shape = tuple([old_shape[i] for i in xrange(len(old_shape))
                               if i not in axis])

        local_data = parent.data

        # if all local data is empty and empty_contractions are forbidden
        # call function on the local_data in order to raise the right exception
        if self.global_dim == 0 and not allow_empty_contractions:
            # this shall raise an exception
            function(local_data, axis=axis, **kwargs)

        # do the contraction on the node's local data
        if self.local_dim == 0 and not allow_empty_contractions:
            # this case will only be reached if some nodes have data and some
            # not
            contracted_local_data = None
        else:
            # if local_dim == 0 but empty contractions will be allowed
            # this will be a `contraction neutral` array.
            contracted_local_data = function(local_data, axis=axis, **kwargs)

        # check if additional contraction along the first axis must be done
        if axis is None or 0 in axis:
            (mpi_op, bufferQ) = op_translate_dict[function]
            contracted_global_data = self._selective_allreduce(
                                        contracted_local_data,
                                        mpi_op,
                                        bufferQ)
            new_dist_strategy = 'not'
        else:
            contracted_global_data = contracted_local_data
            new_dist_strategy = parent.distribution_strategy

        new_dtype = contracted_global_data.dtype

        if new_shape == ():
            result = contracted_global_data
        else:
            # try to store the result in a distributed_data_object with the
            # same distribution_strategy as the parent
            result = parent.copy_empty(global_shape=new_shape,
                                       dtype=new_dtype,
                                       distribution_strategy=new_dist_strategy)

            # However, there are cases where the contracted data no longer
            # follows the prior distribution scheme.
            # Example: FFTW distribution on 4 MPI processes
            # Contracting (4, 4) to (4,).
            # (4, 4) was distributed (1, 4)...(1, 4)
            # (4, ) is not distributed like (1,)...(1,) but like (2,)(2,)()()!
            if result.local_shape != contracted_global_data.shape:
                result = parent.copy_empty(
                                    local_shape=contracted_global_data.shape,
                                    dtype=new_dtype,
                                    distribution_strategy='freeform')
            result.set_local_data(contracted_global_data, copy=False)

        return result
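
    # Illustrative call pattern (assumed attribute and operation names): a
    # reduction such as obj.sum() on a distributed_data_object would reach
    # this helper roughly as
    #
    #     result = obj.distributor.contraction_helper(obj, np.sum, True)
    #
    # with axis=None, so the first axis is contracted too and the per-node
    # partial results are merged via _selective_allreduce.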

    def distribute_data(self, data=None, alias=None,
                        path=None, copy=True, **kwargs):
        '''
        distribute_data checks
        - whether the data is located on all nodes or only on node 0
        - that the shape of 'data' matches the global_shape
        '''

        comm = self.comm

        if 'h5py' in gdi and alias is not None:
            data = self.load_data(alias=alias, path=path)

        local_data_available_Q = (data is not None)
        data_available_Q = np.array(comm.allgather(local_data_available_Q))

        if np.all(data_available_Q == False):
            return np.empty(self.local_shape, dtype=self.dtype, order='C')
        # if all nodes got data, we assume that it is the right data and
        # store it individually.
        elif np.all(data_available_Q == True):
            if isinstance(data, distributed_data_object):
                temp_d2o = data.get_data((slice(self.local_start,
                                                self.local_end),),
                                         local_keys=True,
                                         copy=copy)
                return temp_d2o.get_local_data(copy=False).astype(self.dtype,
                                                                  copy=False)
            else:
                return data[self.local_start:self.local_end].astype(
                    self.dtype,
                    copy=copy)
        else:
            raise ValueError(
                "ERROR: distribute_data must get data on all nodes!")

    def _disperse_data_primitive(self, data, to_key, data_update, from_key,
                                 copy, to_found, to_found_boolean, from_found,
                                 from_found_boolean, **kwargs):
        if np.isscalar(data_update):
            from_key = None

        # Case 1: to_key is a slice-tuple. Hence, the basic indexing/slicing
        # machinery will be used
        if to_found == 'slicetuple':
            if from_found == 'slicetuple':
                return self.disperse_data_to_slices(data=data,
                                                    to_slices=to_key,
                                                    data_update=data_update,
                                                    from_slices=from_key,
                                                    copy=copy,
                                                    **kwargs)
            else:
                if from_key is not None:
                    about.infos.cprint(
                        "INFO: Advanced injection is not available for this " +
                        "combination of to_key and from_key.")
                    prepared_data_update = data_update[from_key]
                else:
                    prepared_data_update = data_update

                return self.disperse_data_to_slices(
                                            data=data,
                                            to_slices=to_key,
                                            data_update=prepared_data_update,
                                            copy=copy,
                                            **kwargs)

        # Case 2: key is an array
        elif (to_found == 'ndarray' or to_found == 'd2o'):
            # Case 2.1: The array is boolean.
            if to_found_boolean:
                if from_key is not None:
                    about.infos.cprint(
                        "INFO: Advanced injection is not available for this " +
                        "combination of to_key and from_key.")
                    prepared_data_update = data_update[from_key]
                else:
                    prepared_data_update = data_update
                return self.disperse_data_to_bool(
                                              data=data,
                                              to_boolean_key=to_key,
                                              data_update=prepared_data_update,
                                              copy=copy,
                                              **kwargs)
            # Case 2.2: The array is not boolean. Only 1-dimensional
            # advanced slicing is supported.
            else:
                if len(to_key.shape) != 1:
                    raise ValueError(about._errors.cstring(
                        "WARNING: Only one-dimensional advanced indexing " +
                        "is supported"))
                # Make a recursive call in order to trigger the 'list'-section
                return self.disperse_data(data=data, to_key=[to_key],
                                          data_update=data_update,
                                          from_key=from_key, copy=copy,
                                          **kwargs)

        # Case 3: to_key is a list. This list is interpreted as
        # one-dimensional advanced indexing list.
        elif to_found == 'indexinglist':
            if from_key is not None:
                about.infos.cprint(
735
                    "INFO: Advanced injection is not available for this " +
736
                    "combination of to_key and from_key.")
737
738
739
                prepared_data_update = data_update[from_key]
            else:
                prepared_data_update = data_update
740
741
742
743
744
            return self.disperse_data_to_list(data=data,
                                              to_list_key=to_key,
                                              data_update=prepared_data_update,
                                              copy=copy,
                                              **kwargs)

    def disperse_data_to_list(self, data, to_list_key, data_update,
                              copy=True, **kwargs):

        if to_list_key == []:
            return data

        local_to_list_key = self._advanced_index_decycler(to_list_key)
        return self._disperse_data_to_list_and_bool_helper(
            data=data,
            local_to_key=local_to_list_key,
            data_update=data_update,
            copy=copy,
            **kwargs)

    def disperse_data_to_bool(self, data, to_boolean_key, data_update,
                              copy=True, **kwargs):
        # Extract the part of the to_boolean_key which corresponds to the
        # local data
        local_to_boolean_key = self.extract_local_data(to_boolean_key)
        return self._disperse_data_to_list_and_bool_helper(
            data=data,
            local_to_key=local_to_boolean_key,
            data_update=data_update,
            copy=copy,
            **kwargs)

    def _disperse_data_to_list_and_bool_helper(self, data, local_to_key,
                                               data_update, copy, **kwargs):
        comm = self.comm
        rank = comm.rank
        # Infer the length and offset of the locally affected data
        locally_affected_data = data[local_to_key]
        data_length = np.shape(locally_affected_data)[0]
        data_length_list = comm.allgather(data_length)
        data_length_offset_list = np.append([0],
                                            np.cumsum(data_length_list)[:-1])

        # Update the local data object with its very own portion
        o = data_length_offset_list
        l = data_length

        if isinstance(data_update, distributed_data_object):
            local_data_update = data_update.get_data(
                                          slice(o[rank], o[rank] + l),
                                          local_keys=True
                                          ).get_local_data(copy=False)
            data[local_to_key] = local_data_update.astype(self.dtype,
                                                          copy=False)
        elif np.isscalar(data_update):
            data[local_to_key] = data_update
        else:
            data[local_to_key] = np.array(data_update[o[rank]:o[rank] + l],
                                          copy=copy).astype(self.dtype,
                                                            copy=False)
        return data
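
    # Worked example of the offset logic above (assumed lengths): with three
    # processes whose locally affected data lengths are 2, 3 and 1,
    # data_length_offset_list becomes [0, 2, 5], so rank 0 consumes
    # data_update[0:2], rank 1 consumes data_update[2:5] and rank 2 consumes
    # data_update[5:6].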

    def disperse_data_to_slices(self, data, to_slices,
                                data_update, from_slices=None, copy=True):
        comm = self.comm
        (to_slices, sliceified) = self._sliceify(to_slices)

        # parse the to_slices object
        localized_to_start, localized_to_stop = self._backshift_and_decycle(
            to_slices[0], self.local_start, self.local_end,
            self.global_shape[0])
        local_to_slice = (slice(localized_to_start, localized_to_stop,
                                to_slices[0].step),) + to_slices[1:]
        local_to_slice_shape = data[local_to_slice].shape

        to_step = to_slices[0].step
        if to_step is None:
            to_step = 1
        elif to_step == 0:
            raise ValueError(about._errors.cstring(
                "ERROR: to_step size == 0!"))

        # Compute the offset of the data the individual node will take.
        # The offset is free of stepsizes. It is the offset in terms of
        # the purely transported data. If to_step < 0, the offset will
        # be calculated in reverse order
        order = np.sign(to_step)

        local_affected_data_length = local_to_slice_shape[0]
        local_affected_data_length_list = np.empty(comm.size, dtype=np.int)
        comm.Allgather(
            [np.array(local_affected_data_length, dtype=np.int), MPI.INT],
            [local_affected_data_length_list, MPI.INT])
        local_affected_data_length_offset_list = np.append([0],
                                                           np.cumsum(
            local_affected_data_length_list[::order])[:-1])[::order]

        if np.isscalar(data_update):
            data[local_to_slice] = data_update
        else:
            # construct the locally adapted from_slice object
            r = comm.rank
            o = local_affected_data_length_offset_list
            l = local_affected_data_length

            data_update = self._enfold(data_update, sliceified)

            # parse the from_slices object
            if from_slices is None:
                from_slices = (slice(None, None, None),)
            (from_slices_start, from_slices_stop) = \
                self._backshift_and_decycle(
                                            slice_object=from_slices[0],
                                            shifted_start=0,
                                            shifted_stop=data_update.shape[0],
                                            global_length=data_update.shape[0])
            if from_slices_start is None:
                raise ValueError(about._errors.cstring(
                    "ERROR: _backshift_and_decycle should never return " +
                    "None for local_start!"))

            # parse the step sizes
            from_step = from_slices[0].step
            if from_step is None:
                from_step = 1
            elif from_step == 0:
                raise ValueError(about._errors.cstring(
                    "ERROR: from_step size == 0!"))

            localized_from_start = from_slices_start + from_step * o[r]
            localized_from_stop = localized_from_start + from_step * l
            if localized_from_stop < 0:
                localized_from_stop = None

            localized_from_slice = (slice(localized_from_start,
                                          localized_from_stop,
                                          from_step),)

            update_slice = localized_from_slice + from_slices[1:]

            if isinstance(data_update, distributed_data_object):
                selected_update = data_update.get_data(
                                 key=update_slice,
                                 local_keys=True)
                local_data_update = selected_update.get_local_data(copy=False)
                local_data_update = local_data_update.astype(self.dtype,
                                                             copy=False)
                if np.prod(np.shape(local_data_update)) != 0:
                    data[local_to_slice] = local_data_update
            # elif np.isscalar(data_update):
            #    data[local_to_slice] = data_update
            else:
                local_data_update = np.array(data_update)[update_slice]
                if np.prod(np.shape(local_data_update)) != 0:
                    data[local_to_slice] = np.array(
                                                local_data_update,
                                                copy=copy).astype(self.dtype,
                                                                  copy=False)