distributor_factory.py 89.1 KB
Newer Older
ultimanet's avatar
ultimanet committed
1
# -*- coding: utf-8 -*-
2

3
4
import numbers

ultimanet's avatar
ultimanet committed
5
import numpy as np
6

theos's avatar
theos committed
7
8
9
from nifty.keepers import about,\
                          global_configuration as gc,\
                          global_dependency_injector as gdi
10

theos's avatar
theos committed
11
from distributed_data_object import distributed_data_object
12

theos's avatar
theos committed
13
14
15
16
from d2o_iter import d2o_slicing_iter,\
                     d2o_not_iter
from d2o_librarian import d2o_librarian
from dtype_converter import dtype_converter
17
18
from cast_axis_to_tuple import cast_axis_to_tuple
from translate_to_mpi_operator import op_translate_dict
19

theos's avatar
theos committed
20
from strategies import STRATEGIES
21

theos's avatar
theos committed
22
23
24
MPI = gdi[gc['mpi_module']]
h5py = gdi.get('h5py')
pyfftw = gdi.get('pyfftw')
ultimanet's avatar
ultimanet committed
25

26

27
class _distributor_factory(object):
28

29
30
    def __init__(self):
        self.distributor_store = {}
31

32
    def parse_kwargs(self, distribution_strategy, comm,
33
34
35
                     global_data=None, global_shape=None,
                     local_data=None, local_shape=None,
                     alias=None, path=None,
36
37
38
39
40
41
42
43
44
45
46
47
                     dtype=None, skip_parsing=False, **kwargs):

        if skip_parsing:
            return_dict = {'comm': comm,
                           'dtype': dtype,
                           'name': distribution_strategy
                           }
            if distribution_strategy in STRATEGIES['global']:
                return_dict['global_shape'] = global_shape
            elif distribution_strategy in STRATEGIES['local']:
                return_dict['local_shape'] = local_shape
            return return_dict
Ultima's avatar
Ultima committed
48

49
        return_dict = {}
50
51
52
53
54

        expensive_checks = gc['d2o_init_checks']

        # Parse the MPI communicator
        if comm is None:
Ultima's avatar
Ultima committed
55
            raise ValueError(about._errors.cstring(
56
57
58
59
60
61
62
63
64
65
66
                "ERROR: The distributor needs MPI-communicator object comm!"))
        else:
            return_dict['comm'] = comm

        if expensive_checks:
            # Check that all nodes got the same distribution_strategy
            strat_list = comm.allgather(distribution_strategy)
            if all(x == strat_list[0] for x in strat_list) == False:
                raise ValueError(about._errors.cstring(
                    "ERROR: The distribution-strategy must be the same on " +
                    "all nodes!"))
67

68
        # Check for an hdf5 file and open it if given
69
        if 'h5py' in gdi and alias is not None:
70
71
72
            # set file path
            file_path = path if (path is not None) else alias
            # open hdf5 file
73
            if h5py.get_config().mpi and gc['mpi_module'] == 'MPI':
Ultima's avatar
Ultima committed
74
75
                f = h5py.File(file_path, 'r', driver='mpio', comm=comm)
            else:
76
77
78
                f = h5py.File(file_path, 'r')
            # open alias in file
            dset = f[alias]
Ultima's avatar
Ultima committed
79
80
81
        else:
            dset = None

82
        # Parse the datatype
Ultima's avatar
Ultima committed
83
        if distribution_strategy in ['not', 'equal', 'fftw'] and \
84
                (dset is not None):
85
            dtype = dset.dtype
86
87

        elif distribution_strategy in ['not', 'equal', 'fftw']:
Ultima's avatar
Ultima committed
88
            if dtype is None:
89
                if global_data is None:
Ultima's avatar
Ultima committed
90
91
                    dtype = np.dtype('float64')
                    about.infos.cprint('INFO: dtype set was set to default.')
92
                else:
Ultima's avatar
Ultima committed
93
                    try:
94
                        dtype = global_data.dtype
Ultima's avatar
Ultima committed
95
                    except(AttributeError):
96
97
98
                        dtype = np.array(global_data).dtype
            else:
                dtype = np.dtype(dtype)
99

100
        elif distribution_strategy in STRATEGIES['local']:
101
            if dtype is None:
Ultima's avatar
Ultima committed
102
103
104
                if isinstance(global_data, distributed_data_object):
                    dtype = global_data.dtype
                elif local_data is not None:
Ultima's avatar
Ultima committed
105
                    try:
106
                        dtype = local_data.dtype
Ultima's avatar
Ultima committed
107
                    except(AttributeError):
108
                        dtype = np.array(local_data).dtype
Ultima's avatar
Ultima committed
109
110
111
                else:
                    dtype = np.dtype('float64')
                    about.infos.cprint('INFO: dtype set was set to default.')
112

Ultima's avatar
Ultima committed
113
            else:
114
                dtype = np.dtype(dtype)
115
116
117
118
119
        if expensive_checks:
            dtype_list = comm.allgather(dtype)
            if all(x == dtype_list[0] for x in dtype_list) == False:
                raise ValueError(about._errors.cstring(
                    "ERROR: The given dtype must be the same on all nodes!"))
Ultima's avatar
Ultima committed
120
        return_dict['dtype'] = dtype
121
122
123

        # Parse the shape
        # Case 1: global-type slicer
124
        if distribution_strategy in STRATEGIES['global']:
Ultima's avatar
Ultima committed
125
126
127
128
129
130
131
132
133
            if dset is not None:
                global_shape = dset.shape
            elif global_data is not None and np.isscalar(global_data) == False:
                global_shape = global_data.shape
            elif global_shape is not None:
                global_shape = tuple(global_shape)
            else:
                raise ValueError(about._errors.cstring(
                    "ERROR: Neither non-0-dimensional global_data nor " +
134
                    "global_shape nor hdf5 file supplied!"))
Ultima's avatar
Ultima committed
135
136
            if global_shape == ():
                raise ValueError(about._errors.cstring(
Ultima's avatar
Ultima committed
137
                    "ERROR: global_shape == () is not a valid shape!"))
138

139
140
141
142
143
144
145
            if expensive_checks:
                global_shape_list = comm.allgather(global_shape)
                if not all(x == global_shape_list[0]
                           for x in global_shape_list):
                    raise ValueError(about._errors.cstring(
                        "ERROR: The global_shape must be the same on all " +
                        "nodes!"))
Ultima's avatar
Ultima committed
146
147
            return_dict['global_shape'] = global_shape

148
        # Case 2: local-type slicer
Ultima's avatar
Ultima committed
149
150
151
152
        elif distribution_strategy in ['freeform']:
            if isinstance(global_data, distributed_data_object):
                local_shape = global_data.local_shape
            elif local_data is not None and np.isscalar(local_data) == False:
Ultima's avatar
Ultima committed
153
154
155
156
157
158
                local_shape = local_data.shape
            elif local_shape is not None:
                local_shape = tuple(local_shape)
            else:
                raise ValueError(about._errors.cstring(
                    "ERROR: Neither non-0-dimensional local_data nor " +
159
                    "local_shape nor global d2o supplied!"))
Ultima's avatar
Ultima committed
160
161
            if local_shape == ():
                raise ValueError(about._errors.cstring(
162
163
                    "ERROR: local_shape == () is not a valid shape!"))

164
165
166
167
168
169
170
171
            if expensive_checks:
                local_shape_list = comm.allgather(local_shape[1:])
                cleared_set = set(local_shape_list)
                cleared_set.discard(())
                if len(cleared_set) > 1:
                    raise ValueError(about._errors.cstring(
                        "ERROR: All but the first entry of local_shape " +
                        "must be the same on all nodes!"))
Ultima's avatar
Ultima committed
172
173
            return_dict['local_shape'] = local_shape

174
        # Add the name of the distributor if needed
175
176
        if distribution_strategy in ['equal', 'fftw', 'freeform']:
            return_dict['name'] = distribution_strategy
177
178

        # close the file-handle
Ultima's avatar
Ultima committed
179
180
181
        if dset is not None:
            f.close()

182
        return return_dict
183

Ultima's avatar
Ultima committed
184
    def hash_arguments(self, distribution_strategy, **kwargs):
185
        kwargs = kwargs.copy()
186

187
188
        comm = kwargs['comm']
        kwargs['comm'] = id(comm)
189

190
        if 'global_shape' in kwargs:
Ultima's avatar
Ultima committed
191
            kwargs['global_shape'] = kwargs['global_shape']
192
        if 'local_shape' in kwargs:
193
194
195
            local_shape = kwargs['local_shape']
            local_shape_list = comm.allgather(local_shape)
            kwargs['local_shape'] = tuple(local_shape_list)
196

Ultima's avatar
Ultima committed
197
        kwargs['dtype'] = self.dictionize_np(kwargs['dtype'])
198
        kwargs['distribution_strategy'] = distribution_strategy
199

200
        return frozenset(kwargs.items())
ultimanet's avatar
ultimanet committed
201

202
    def dictionize_np(self, x):
203
        dic = x.type.__dict__.items()
204
        if x is np.float:
205
            dic[24] = 0
206
207
            dic[29] = 0
            dic[37] = 0
208
209
        return frozenset(dic)

210
    def get_distributor(self, distribution_strategy, comm, **kwargs):
211
        # check if the distribution strategy is known
212
        if distribution_strategy not in STRATEGIES['all']:
213
            raise ValueError(about._errors.cstring(
214
                "ERROR: Unknown distribution strategy supplied."))
215
216

        # parse the kwargs
Ultima's avatar
Ultima committed
217
        parsed_kwargs = self.parse_kwargs(
218
219
220
            distribution_strategy=distribution_strategy,
            comm=comm,
            **kwargs)
221

Ultima's avatar
Ultima committed
222
223
        hashed_kwargs = self.hash_arguments(distribution_strategy,
                                            **parsed_kwargs)
224
        # check if the distributors has already been produced in the past
225
        if hashed_kwargs in self.distributor_store:
Ultima's avatar
Ultima committed
226
            return self.distributor_store[hashed_kwargs]
227
        else:
228
            # produce new distributor
229
            if distribution_strategy == 'not':
Ultima's avatar
Ultima committed
230
                produced_distributor = _not_distributor(**parsed_kwargs)
231

232
233
            elif distribution_strategy == 'equal':
                produced_distributor = _slicing_distributor(
234
235
                    slicer=_equal_slicer,
                    **parsed_kwargs)
236

237
238
            elif distribution_strategy == 'fftw':
                produced_distributor = _slicing_distributor(
239
240
                    slicer=_fftw_slicer,
                    **parsed_kwargs)
Ultima's avatar
Ultima committed
241
242
            elif distribution_strategy == 'freeform':
                produced_distributor = _slicing_distributor(
243
244
                    slicer=_freeform_slicer,
                    **parsed_kwargs)
245
246

            self.distributor_store[hashed_kwargs] = produced_distributor
Ultima's avatar
Ultima committed
247
            return self.distributor_store[hashed_kwargs]
248
249


250
# Module-level singleton: all distributed_data_object instances share this
# factory (and therefore its distributor cache).
distributor_factory = _distributor_factory()
Ultima's avatar
Ultima committed
251
252
253
254
255
256


def _infer_key_type(key):
    if key is None:
        return (None, None)
    found_boolean = False
257
    # Check which case we got:
258
    if isinstance(key, tuple) or isinstance(key, slice) or np.isscalar(key):
259
260
        # Check if there is something different in the array than
        # scalars and slices
Ultima's avatar
Ultima committed
261
262
        if isinstance(key, slice) or np.isscalar(key):
            key = [key]
263

Ultima's avatar
Ultima committed
264
265
266
        scalarQ = np.array(map(np.isscalar, key))
        sliceQ = np.array(map(lambda z: isinstance(z, slice), key))
        if np.all(scalarQ + sliceQ):
267
            found = 'slicetuple'
Ultima's avatar
Ultima committed
268
269
270
271
272
273
274
275
276
277
        else:
            found = 'indexinglist'
    elif isinstance(key, np.ndarray):
        found = 'ndarray'
        found_boolean = (key.dtype.type == np.bool_)
    elif isinstance(key, distributed_data_object):
        found = 'd2o'
        found_boolean = (key.dtype == np.bool_)
    elif isinstance(key, list):
        found = 'indexinglist'
278
279
    else:
        raise ValueError(about._errors.cstring("ERROR: Unknown keytype!"))
Ultima's avatar
Ultima committed
280
281
282
283
    return (found, found_boolean)


class distributor(object):
    """Abstract base class for distributors.

    Subclasses must provide ``self.comm`` (an MPI communicator) and the
    ``_disperse_data_primitive`` method; this base class implements the
    common key-gathering logic for node-individual (``local_keys=True``)
    dispersal.
    """

    def disperse_data(self, data, to_key, data_update, from_key=None,
                      local_keys=False, copy=True, **kwargs):
        """Write ``data_update[from_key]`` into ``data[to_key]``.

        If ``local_keys`` is False this is a single collective call.
        Otherwise every node supplies its own keys/update; the keys and
        updates of all nodes are gathered and dispersed one node after
        another (the loop body is collective, so all nodes iterate it
        in lockstep).
        """
        # Check which keys we got:
        (to_found, to_found_boolean) = _infer_key_type(to_key)
        (from_found, from_found_boolean) = _infer_key_type(from_key)

        comm = self.comm

        if local_keys is False:
            return self._disperse_data_primitive(
                                         data=data,
                                         to_key=to_key,
                                         data_update=data_update,
                                         from_key=from_key,
                                         copy=copy,
                                         to_found=to_found,
                                         to_found_boolean=to_found_boolean,
                                         from_found=from_found,
                                         from_found_boolean=from_found_boolean,
                                         **kwargs)
        else:
            # assert that all to_keys are from same type
            to_found_list = comm.allgather(to_found)
            assert(all(x == to_found_list[0] for x in to_found_list))
            to_found_boolean_list = comm.allgather(to_found_boolean)
            assert(all(x == to_found_boolean_list[0] for x in
                       to_found_boolean_list))
            from_found_list = comm.allgather(from_found)
            assert(all(x == from_found_list[0] for x in from_found_list))
            from_found_boolean_list = comm.allgather(from_found_boolean)
            assert(all(x == from_found_boolean_list[0] for
                       x in from_found_boolean_list))

            # gather the local to_keys into a global to_key_list
            # Case 1: the to_keys are not distributed_data_objects
            # -> allgather does the job
            if to_found != 'd2o':
                to_key_list = comm.allgather(to_key)
            # Case 2: if the to_keys are distributed_data_objects, gather
            # the index of the array and build the to_key_list with help
            # from the librarian
            else:
                to_index_list = comm.allgather(to_key.index)
                to_key_list = [d2o_librarian[z] for z in to_index_list]

            # gather the local from_keys. It is the same procedure as above
            if from_found != 'd2o':
                from_key_list = comm.allgather(from_key)
            else:
                from_index_list = comm.allgather(from_key.index)
                from_key_list = [d2o_librarian[z]
                                 for z in from_index_list]

            local_data_update_is_scalar = np.isscalar(data_update)
            local_scalar_list = comm.allgather(local_data_update_is_scalar)
            for i in xrange(len(to_key_list)):
                if np.all(np.array(local_scalar_list) == True):
                    # every node's update is a scalar: just exchange them
                    scalar_list = comm.allgather(data_update)
                    temp_data_update = scalar_list[i]
                elif isinstance(data_update, distributed_data_object):
                    data_update_index_list = comm.allgather(data_update.index)
                    data_update_list = [d2o_librarian[z]
                                        for z in data_update_index_list]
                    temp_data_update = data_update_list[i]
                else:
                    # build a temporary freeform d2o which only contains data
                    # from node i
                    if comm.rank == i:
                        temp_shape = np.shape(data_update)
                        try:
                            temp_dtype = np.dtype(data_update.dtype)
                        except(TypeError):
                            # NOTE(review): a data_update without a .dtype
                            # attribute raises AttributeError, not TypeError
                            # -- confirm whether that case can occur here
                            temp_dtype = np.array(data_update).dtype
                    else:
                        temp_shape = None
                        temp_dtype = None
                    temp_shape = comm.bcast(temp_shape, root=i)
                    temp_dtype = comm.bcast(temp_dtype, root=i)

                    if comm.rank != i:
                        # every other node contributes an empty slab
                        temp_shape = list(temp_shape)
                        temp_shape[0] = 0
                        temp_shape = tuple(temp_shape)
                        temp_data = np.empty(temp_shape, dtype=temp_dtype)
                    else:
                        temp_data = data_update
                    temp_data_update = distributed_data_object(
                                        local_data=temp_data,
                                        distribution_strategy='freeform',
                                        copy=False,
                                        comm=self.comm)
                # disperse the data one after another
                self._disperse_data_primitive(
                                      data=data,
                                      to_key=to_key_list[i],
                                      data_update=temp_data_update,
                                      from_key=from_key_list[i],
                                      copy=copy,
                                      to_found=to_found,
                                      to_found_boolean=to_found_boolean,
                                      from_found=from_found,
                                      from_found_boolean=from_found_boolean,
                                      **kwargs)

390

Ultima's avatar
Ultima committed
391
class _slicing_distributor(distributor):
392
    def __init__(self, slicer, name, dtype, comm, **remaining_parsed_kwargs):
        """Set up the slicing geometry of this distributor.

        Parameters
        ----------
        slicer : callable
            Returns ``(local_start, local_end, global_shape)`` for this
            process when called with ``comm`` and the remaining kwargs.
        name : str
            The distribution strategy ('equal', 'fftw' or 'freeform').
        dtype : np.dtype-compatible
            Data type of the distributed array; must be known to mpi4py.
        comm :
            The MPI communicator over which the data is distributed.

        Raises TypeError if the dtype cannot be mapped to an MPI type.
        """
        self.comm = comm
        self.distribution_strategy = name
        self.dtype = np.dtype(dtype)

        self._my_dtype_converter = dtype_converter

        if not self._my_dtype_converter.known_np_Q(self.dtype):
            raise TypeError(about._errors.cstring(
                "ERROR: The datatype " + str(self.dtype.__repr__()) +
                " is not known to mpi4py."))

        self.mpi_dtype = self._my_dtype_converter.to_mpi(self.dtype)

        # compute this process' slice of the first (distributed) axis
        self.slicer = slicer
        self._local_size = self.slicer(comm=comm, **remaining_parsed_kwargs)
        self.local_start = self._local_size[0]
        self.local_end = self._local_size[1]
        self.global_shape = self._local_size[2]
        self.global_dim = reduce(lambda x, y: x*y, self.global_shape)

        self.local_length = self.local_end - self.local_start
        self.local_shape = (self.local_length,) + tuple(self.global_shape[1:])
        self.local_dim = np.product(self.local_shape)
        # gather the number of local elements of every process
        self.local_dim_list = np.empty(comm.size, dtype=np.int)
        comm.Allgather([np.array(self.local_dim, dtype=np.int), MPI.INT],
                       [self.local_dim_list, MPI.INT])
        # flat-index offset of this process' first element
        self.local_dim_offset = np.sum(self.local_dim_list[0:comm.rank])

        self.local_slice = np.array([self.local_start, self.local_end,
                                     self.local_length, self.local_dim,
                                     self.local_dim_offset],
                                    dtype=np.int)
        # collect all local_slices
        self.all_local_slices = np.empty((comm.size, 5), dtype=np.int)
        comm.Allgather([np.array((self.local_slice,), dtype=np.int), MPI.INT],
                       [self.all_local_slices, MPI.INT])
430
431

    def initialize_data(self, global_data, local_data, alias, path, hermitian,
                        copy, **kwargs):
        """Build this process' local data array from the given input.

        Input precedence: hdf5 alias > global_data > local_data; missing
        input yields an uninitialized (np.empty) local array.  Returns a
        ``(local_data, hermitian)`` tuple, where ``hermitian`` is set True
        when the data was filled from a scalar.
        Raises TypeError for unknown distribution strategies.
        """
        if 'h5py' in gdi and alias is not None:
            local_data = self.load_data(alias=alias, path=path)
            return (local_data, hermitian)

        if self.distribution_strategy in ['equal', 'fftw']:
            if np.isscalar(global_data):
                local_data = np.empty(self.local_shape, dtype=self.dtype)
                local_data.fill(global_data)
                hermitian = True
            else:
                local_data = self.distribute_data(data=global_data,
                                                  copy=copy)
        elif self.distribution_strategy in ['freeform']:
            if isinstance(global_data, distributed_data_object):
                local_data = global_data.get_local_data(copy=copy)
            elif np.isscalar(local_data):
                temp_local_data = np.empty(self.local_shape,
                                           dtype=self.dtype)
                temp_local_data.fill(local_data)
                local_data = temp_local_data
                hermitian = True
            elif local_data is None:
                local_data = np.empty(self.local_shape, dtype=self.dtype)
            elif isinstance(local_data, np.ndarray):
                local_data = local_data.astype(
                               self.dtype, copy=copy).reshape(self.local_shape)
            else:
                local_data = np.array(local_data).astype(
                    self.dtype, copy=copy).reshape(self.local_shape)
        else:
            # fixed typo in the original message ("istribution")
            raise TypeError(about._errors.cstring(
                "ERROR: Unknown distribution strategy"))
        return (local_data, hermitian)

467
    def globalize_flat_index(self, index):
        """Translate a node-local flat index into a global flat index."""
        return int(index) + self.local_dim_offset
469

470
471
472
    def globalize_index(self, index):
        """Translate a node-local nd-index into a global nd-index tuple.

        Only the first (distributed) axis is shifted by ``local_start``.
        Out-of-bounds entries are clipped into the valid range (with a
        warning).  Raises TypeError if the index length does not match
        the array's dimensionality.
        """
        index = np.array(index, dtype=np.int).flatten()
        if index.shape != (len(self.global_shape),):
            raise TypeError(about._errors.cstring("ERROR: Length\
                of index tuple does not match the array's shape!"))
        globalized_index = index
        globalized_index[0] = index[0] + self.local_start
        # ensure that the globalized index list is within the bounds
        global_index_memory = globalized_index
        globalized_index = np.clip(globalized_index,
                                   -np.array(self.global_shape),
                                   np.array(self.global_shape) - 1)
        if np.any(global_index_memory != globalized_index):
            about.warnings.cprint("WARNING: Indices were clipped!")
        globalized_index = tuple(globalized_index)
        return globalized_index
486

487
    def _allgather(self, thing, comm=None):
        """Allgather ``thing`` over ``comm`` (defaults to ``self.comm``)."""
        if comm is None:
            comm = self.comm
        gathered_things = comm.allgather(thing)
        return gathered_things
492

493
    def _Allreduce_helper(self, sendbuf, recvbuf, op):
        """Buffer-based MPI Allreduce with automatic MPI-dtype lookup.

        ``sendbuf``/``recvbuf`` are numpy arrays; their dtypes are mapped
        to MPI types via the distributor's dtype converter.  The reduced
        result is written into ``recvbuf``, which is also returned.
        """
        send_dtype = self._my_dtype_converter.to_mpi(sendbuf.dtype)
        recv_dtype = self._my_dtype_converter.to_mpi(recvbuf.dtype)
        self.comm.Allreduce([sendbuf, send_dtype],
                            [recvbuf, recv_dtype],
                            op=op)
        return recvbuf

501
502
503
504
505
    def _selective_allreduce(self, data, op, bufferQ=False):
        """Reduce ``data`` with ``op`` over only those processes that hold
        non-None, non-empty data.

        ``bufferQ`` states whether ``op`` supports MPI's buffer interface;
        in that case real (non-complex) numpy arrays are combined via
        ``Bcast`` of raw buffers, otherwise generic object ``bcast`` is
        used.  This is a collective call: all processes must participate.
        Raises ValueError if no process holds non-None data.
        """
        size = self.comm.size
        rank = self.comm.rank

        if size == 1:
            if data is None:
                raise ValueError("ERROR: No process with non-None data.")
            result_data = data
        else:
            # infer which data should be included in the allreduce and if its
            # array data:
            #   0 -> None, 1 -> empty array, 2 -> non-array,
            #   3 -> complex array, 4 -> real array
            if data is None:
                got_array = np.array([0])
            elif not isinstance(data, np.ndarray):
                got_array = np.array([2])
            elif reduce(lambda x, y: x*y, data.shape) == 0:
                got_array = np.array([1])
            elif np.issubdtype(data.dtype, np.complexfloating):
                # MPI.MAX and MPI.MIN do not support complex data types
                got_array = np.array([3])
            else:
                got_array = np.array([4])

            got_array_list = np.empty(size, dtype=np.int)
            self.comm.Allgather([got_array, MPI.INT],
                                [got_array_list, MPI.INT])

            # if every process holds an empty array there is nothing to do
            if reduce(lambda x, y: x & y, got_array_list == 1):
                return data

            # get first node with non-None data
            try:
                start = next(i for i in xrange(size) if got_array_list[i] > 1)
            except(StopIteration):
                raise ValueError("ERROR: No process with non-None data.")

            # check if the Uppercase function can be used or not
            # -> check if op supports buffers and if we got real array-data
            # (bugfix: the original indexed the 1-element local `got_array`
            # with `start`, which raises IndexError whenever start > 0)
            if bufferQ and got_array_list[start] == 4:
                # Send the dtype and shape from the start process to the
                # others; only the start rank may touch data's attributes,
                # since other ranks may hold None
                if rank == start:
                    meta = (data.dtype, data.shape)
                else:
                    meta = None
                (new_dtype, new_shape) = self.comm.bcast(meta, root=start)
                mpi_dtype = self._my_dtype_converter.to_mpi(new_dtype)
                if rank == start:
                    result_data = data
                else:
                    result_data = np.empty(new_shape, dtype=new_dtype)

                self.comm.Bcast([result_data, mpi_dtype], root=start)

                for i in xrange(start+1, size):
                    if got_array_list[i] > 1:
                        if rank == i:
                            temp_data = data
                        else:
                            temp_data = np.empty(new_shape, dtype=new_dtype)
                        self.comm.Bcast([temp_data, mpi_dtype], root=i)
                        result_data = op(result_data, temp_data)

            else:
                result_data = self.comm.bcast(data, root=start)
                for i in xrange(start+1, size):
                    if got_array_list[i] > 1:
                        temp_data = self.comm.bcast(data, root=i)
                        result_data = op(result_data, temp_data)
        return result_data

    def contraction_helper(self, parent, function, allow_empty_contractions,
                           axis=None, **kwargs):
        """Apply an axis-contraction (e.g. a sum-like reduction) to ``parent``.

        Parameters
        ----------
        parent : distributed_data_object
            The array being contracted.
        function : callable
            The local numpy-style reduction; must be a key of
            ``op_translate_dict`` when the distributed axis is contracted.
        allow_empty_contractions : bool
            Accepted for interface compatibility; not used in this method.
        axis : None, int or tuple of ints
            Axes to contract; ``()`` returns a plain copy, None contracts
            everything to a scalar.

        Returns the contracted scalar, or a new distributed_data_object.
        """
        if axis == ():
            return parent.copy()

        old_shape = parent.shape
        axis = cast_axis_to_tuple(axis)
        if axis is None:
            new_shape = ()
        else:
            new_shape = tuple([old_shape[i] for i in xrange(len(old_shape))
                               if i not in axis])

        local_data = parent.data

        try:
            contracted_local_data = function(local_data, axis=axis, **kwargs)
        except(ValueError):
            # e.g. empty local data; handled below
            contracted_local_data = None

        # check if additional contraction along the first axis must be done
        if axis is None or 0 in axis:
            (mpi_op, bufferQ) = op_translate_dict[function]
            contracted_global_data = self._selective_allreduce(
                                        contracted_local_data,
                                        mpi_op,
                                        bufferQ)
            new_dist_strategy = 'not'
        else:
            if contracted_local_data is None:
                # raise the exception implicitly
                function(local_data, axis=axis, **kwargs)
            contracted_global_data = contracted_local_data
            new_dist_strategy = parent.distribution_strategy

        new_dtype = contracted_global_data.dtype

        if new_shape == ():
            result = contracted_global_data
        else:
            # try to store the result in a distributed_data_object with the
            # distribution_strategy as parent
            result = parent.copy_empty(global_shape=new_shape,
                                       dtype=new_dtype,
                                       distribution_strategy=new_dist_strategy)

            # However, there are cases where the contracted data does not any
            # longer follow the prior distribution scheme.
            # Example: FFTW distribution on 4 MPI processes
            # Contracting (4, 4) to (4,).
            # (4, 4) was distributed (1, 4)...(1, 4)
            # (4, ) is not distributed like (1,)...(1,) but like (2,)(2,)()()!
            if result.local_shape != contracted_global_data.shape:
                result = parent.copy_empty(
                                    local_shape=contracted_global_data.shape,
                                    dtype=new_dtype,
                                    distribution_strategy='freeform')
            result.set_local_data(contracted_global_data, copy=False)

        return result

631
    def distribute_data(self, data=None, alias=None,
                        path=None, copy=True, **kwargs):
        '''
        Distribute `data` among the nodes and return the local portion.

        Checks
        - whether the data is located on all nodes or only on node 0
        - that the shape of 'data' matches the global_shape
        '''
        comm = self.comm

        # an hdf5 alias takes precedence over directly supplied data
        if alias is not None and 'h5py' in gdi:
            data = self.load_data(alias=alias, path=path)

        # gather on which nodes data was actually supplied
        has_local_data = (data is not None)
        availability = np.array(comm.allgather(has_local_data))

        if not availability.any():
            # no node has data: return an uninitialized local array
            return np.empty(self.local_shape, dtype=self.dtype, order='C')

        if not availability.all():
            # data on only some nodes is ambiguous and not supported
            raise ValueError(
                "ERROR: distribute_data must get data on all nodes!")

        # every node has data; we assume it is the right data and each
        # node extracts its own portion individually
        if isinstance(data, distributed_data_object):
            sliced = data.get_data((slice(self.local_start,
                                          self.local_end),),
                                   local_keys=True,
                                   copy=copy)
            return sliced.get_local_data(copy=False).astype(self.dtype,
                                                            copy=False)

        return data[self.local_start:self.local_end].astype(self.dtype,
                                                            copy=copy)

    def _disperse_data_primitive(self, data, to_key, data_update, from_key,
                                 copy, to_found, to_found_boolean, from_found,
                                 from_found_boolean, **kwargs):
        '''
        Dispatch a data-update request to the specialized disperse method
        matching the type of `to_key` (slice-tuple, boolean/integer array,
        or advanced-indexing list).
        '''
        # a scalar update is broadcast to all selected entries, so any
        # given from_key is meaningless
        if np.isscalar(data_update):
            from_key = None

        # Case 1: to_key is a slice-tuple. Hence, the basic indexing/slicing
        # machinery will be used
        if to_found == 'slicetuple':
            if from_found == 'slicetuple':
                return self.disperse_data_to_slices(data=data,
                                                    to_slices=to_key,
                                                    data_update=data_update,
                                                    from_slices=from_key,
                                                    copy=copy,
                                                    **kwargs)
            else:
                prepared_data_update = self._preprocess_data_update(
                                                        data_update, from_key)
                return self.disperse_data_to_slices(
                                            data=data,
                                            to_slices=to_key,
                                            data_update=prepared_data_update,
                                            copy=copy,
                                            **kwargs)

        # Case 2: key is an array
        elif (to_found == 'ndarray' or to_found == 'd2o'):
            # Case 2.1: The array is boolean.
            if to_found_boolean:
                prepared_data_update = self._preprocess_data_update(
                                                        data_update, from_key)
                return self.disperse_data_to_bool(
                                              data=data,
                                              to_boolean_key=to_key,
                                              data_update=prepared_data_update,
                                              copy=copy,
                                              **kwargs)

            # Case 2.2: The array is not boolean. Only 1-dimensional
            # advanced slicing is supported.
            else:
                if len(to_key.shape) != 1:
                    raise ValueError(about._errors.cstring(
                        "WARNING: Only one-dimensional advanced indexing " +
                        "is supported"))
                # Make a recursive call in order to trigger the 'list'-section
                return self.disperse_data(data=data, to_key=[to_key],
                                          data_update=data_update,
                                          from_key=from_key, copy=copy,
                                          **kwargs)

        # Case 3 : to_key is a list. This list is interpreted as
        # one-dimensional advanced indexing list.
        elif to_found == 'indexinglist':
            prepared_data_update = self._preprocess_data_update(
                                                        data_update, from_key)
            return self.disperse_data_to_list(data=data,
                                              to_list_key=to_key,
                                              data_update=prepared_data_update,
                                              copy=copy,
                                              **kwargs)

    @staticmethod
    def _preprocess_data_update(data_update, from_key):
        # Advanced injection cannot honor a separate from_key, so the
        # from_key is applied to data_update up front (with an INFO note).
        if from_key is None:
            return data_update
        about.infos.cprint(
            "INFO: Advanced injection is not available for this " +
            "combination of to_key and from_key.")
        return data_update[from_key]

    def disperse_data_to_list(self, data, to_list_key, data_update,
                              copy=True, **kwargs):
        '''
        Write `data_update` into `data` at the positions selected by an
        advanced-indexing list key, and return the updated local array.
        '''
        # an empty key selects nothing, hence nothing to update
        if to_list_key == []:
            return data

        # translate the global index list into node-local indices
        decycled_key = self._advanced_index_decycler(to_list_key)
        return self._disperse_data_to_list_and_bool_helper(
            data=data,
            local_to_key=decycled_key,
            data_update=data_update,
            copy=copy,
            **kwargs)

    def disperse_data_to_bool(self, data, to_boolean_key, data_update,
                              copy=True, **kwargs):
        '''
        Write `data_update` into `data` at the positions where the boolean
        key is True, and return the updated local array.
        '''
        # restrict the boolean key to the part that addresses local data
        local_key = self.extract_local_data(to_boolean_key)
        return self._disperse_data_to_list_and_bool_helper(
            data=data,
            local_to_key=local_key,
            data_update=data_update,
            copy=copy,
            **kwargs)

    def _disperse_data_to_list_and_bool_helper(self, data, local_to_key,
                                               data_update, copy, **kwargs):
        '''
        Store this node's portion of `data_update` into the entries of the
        local `data` array selected by `local_to_key`, and return `data`.
        '''
        comm = self.comm
        rank = comm.rank

        # number of locally affected entries on every node
        local_length = np.shape(data[local_to_key])[0]
        length_list = comm.allgather(local_length)
        # offset of each node's portion within the flattened update
        offset_list = np.append([0], np.cumsum(length_list)[:-1])

        start = offset_list[rank]
        stop = start + local_length

        # update the local data object with its very own portion
        if np.isscalar(data_update):
            data[local_to_key] = data_update
        elif isinstance(data_update, distributed_data_object):
            update_chunk = data_update.get_data(
                                        slice(start, stop),
                                        local_keys=True
                                        ).get_local_data(copy=False)
            data[local_to_key] = update_chunk.astype(self.dtype,
                                                     copy=False)
        else:
            data[local_to_key] = np.array(data_update[start:stop],
                                          copy=copy).astype(self.dtype,
                                                            copy=False)
        return data

    def disperse_data_to_slices(self, data, to_slices,
802
                                data_update, from_slices=None, copy=True):
803
804
805
806
        comm = self.comm
        (to_slices, sliceified) = self._sliceify(to_slices)

        # parse the to_slices object
807
808
809
810
811
        localized_to_start, localized_to_stop = self._backshift_and_decycle(
            to_slices[0], self.local_start, self.local_end,
            self.global_shape[0])
        local_to_slice = (slice(localized_to_start, localized_to_stop,
                                to_slices[0].step),) + to_slices[1:]
812
        local_to_slice_shape = data[local_to_slice].shape
813

Ultima's avatar
Ultima committed
814
815
816
        to_step = to_slices[0].step
        if to_step is None:
            to_step = 1
817
        elif to_step == 0:
818
            raise ValueError(about._errors.cstring(
Ultima's avatar
Ultima committed
819
820
                "ERROR: to_step size == 0!"))

821
822
823
824
        # Compute the offset of the data the individual node will take.
        # The offset is free of stepsizes. It is the offset in terms of
        # the purely transported data. If to_step < 0, the offset will
        # be calculated in reverse order
Ultima's avatar
Ultima committed
825
        order = np.sign(to_step)
826

Ultima's avatar
Ultima committed
827
        local_affected_data_length = local_to_slice_shape[0]
828
829
830
        local_affected_data_length_list = np.empty(comm.size, dtype=np.int)
        comm.Allgather(
            [np.array(local_affected_data_length, dtype=np.int), MPI.INT],
831
            [local_affected_data_length_list, MPI.INT])
832
833
834
        local_affected_data_length_offset_list = np.append([0],
                                                           np.cumsum(
            local_affected_data_length_list[::order])[:-1])[::order]
Ultima's avatar
Ultima committed
835

836
        if np.isscalar(data_update):
Ultima's avatar
Ultima committed
837
838
            data[local_to_slice] = data_update
        else:
839
            # construct the locally adapted from_slice object
Ultima's avatar
Ultima committed
840
841
842
843
            r = comm.rank
            o = local_affected_data_length_offset_list
            l = local_affected_data_length

844
845
846
            data_update = self._enfold(data_update, sliceified)

            # parse the from_slices object
Ultima's avatar
Ultima committed
847
            if from_slices is None:
848
                from_slices = (slice(None, None, None),)
849
850
851
852
853
854
            (from_slices_start, from_slices_stop) = \
                self._backshift_and_decycle(
                                            slice_object=from_slices[0],
                                            shifted_start=0,
                                            shifted_stop=data_update.shape[0],
                                            global_length=data_update.shape[0])
Ultima's avatar
Ultima committed
855
            if from_slices_start is None:
856
857
858
                raise ValueError(about._errors.cstring(
                    "ERROR: _backshift_and_decycle should never return " +
                    "None for local_start!"))
859
860

            # parse the step sizes
861
            from_step = from_slices[0].step
Ultima's avatar
Ultima committed
862
            if from_step is None:
863
                from_step = 1
864
            elif from_step == 0:
865
                raise ValueError(about._errors.cstring(
866
                    "ERROR: from_step size == 0!"))
867

868
            localized_from_start = from_slices_start + from_step * o[r]
869
            localized_from_stop = localized_from_start + from_step * l
870
871
            if localized_from_stop < 0:
                localized_from_stop = None