distributor_factory.py
# -*- coding: utf-8 -*-

import numpy as np

from nifty.keepers import about,\
                          global_configuration as gc,\
                          global_dependency_injector as gdi

from distributed_data_object import distributed_data_object

from d2o_iter import d2o_slicing_iter,\
                     d2o_not_iter
from d2o_librarian import d2o_librarian
from dtype_converter import dtype_converter

from strategies import STRATEGIES

MPI = gdi[gc['mpi_module']]
h5py = gdi.get('h5py')
pyfftw = gdi.get('pyfftw')


class _distributor_factory(object):

    def __init__(self):
        self.distributor_store = {}

    def parse_kwargs(self, distribution_strategy, comm,
                     global_data=None, global_shape=None,
                     local_data=None, local_shape=None,
                     alias=None, path=None,
                     dtype=None, skip_parsing=False, **kwargs):

        if skip_parsing:
            return_dict = {'comm': comm,
                           'dtype': dtype,
                           'name': distribution_strategy
                           }
            if distribution_strategy in STRATEGIES['global']:
                return_dict['global_shape'] = global_shape
            elif distribution_strategy in STRATEGIES['local']:
                return_dict['local_shape'] = local_shape
            return return_dict

        return_dict = {}

        expensive_checks = gc['d2o_init_checks']

        # Parse the MPI communicator
        if comm is None:
            raise ValueError(about._errors.cstring(
                "ERROR: The distributor needs MPI-communicator object comm!"))
        else:
            return_dict['comm'] = comm

        if expensive_checks:
            # Check that all nodes got the same distribution_strategy
            strat_list = comm.allgather(distribution_strategy)
            if all(x == strat_list[0] for x in strat_list) == False:
                raise ValueError(about._errors.cstring(
                    "ERROR: The distribution-strategy must be the same on " +
                    "all nodes!"))

        # Check for an hdf5 file and open it if given
        if 'h5py' in gdi and alias is not None:
            # set file path
            file_path = path if (path is not None) else alias
            # open hdf5 file
            if h5py.get_config().mpi and gc['mpi_module'] == 'MPI':
                f = h5py.File(file_path, 'r', driver='mpio', comm=comm)
            else:
                f = h5py.File(file_path, 'r')
            # open alias in file
            dset = f[alias]
        else:
            dset = None

        # Parse the datatype
        if distribution_strategy in ['not', 'equal', 'fftw'] and \
                (dset is not None):
            dtype = dset.dtype

        elif distribution_strategy in ['not', 'equal', 'fftw']:
            if dtype is None:
                if global_data is None:
                    dtype = np.dtype('float64')
                    about.infos.cprint('INFO: dtype was set to default.')
                else:
                    try:
                        dtype = global_data.dtype
                    except(AttributeError):
                        dtype = np.array(global_data).dtype
            else:
                dtype = np.dtype(dtype)

        elif distribution_strategy in STRATEGIES['local']:
            if dtype is None:
                if isinstance(global_data, distributed_data_object):
                    dtype = global_data.dtype
                elif local_data is not None:
                    try:
                        dtype = local_data.dtype
                    except(AttributeError):
                        dtype = np.array(local_data).dtype
                else:
                    dtype = np.dtype('float64')
                    about.infos.cprint('INFO: dtype was set to default.')

            else:
                dtype = np.dtype(dtype)
        if expensive_checks:
            dtype_list = comm.allgather(dtype)
            if all(x == dtype_list[0] for x in dtype_list) == False:
                raise ValueError(about._errors.cstring(
                    "ERROR: The given dtype must be the same on all nodes!"))
        return_dict['dtype'] = dtype

        # Parse the shape
        # Case 1: global-type slicer
        if distribution_strategy in STRATEGIES['global']:
            if dset is not None:
                global_shape = dset.shape
            elif global_data is not None and np.isscalar(global_data) == False:
                global_shape = global_data.shape
            elif global_shape is not None:
                global_shape = tuple(global_shape)
            else:
                raise ValueError(about._errors.cstring(
                    "ERROR: Neither non-0-dimensional global_data nor " +
                    "global_shape nor hdf5 file supplied!"))
            if global_shape == ():
                raise ValueError(about._errors.cstring(
                    "ERROR: global_shape == () is not a valid shape!"))

            if expensive_checks:
                global_shape_list = comm.allgather(global_shape)
                if not all(x == global_shape_list[0]
                           for x in global_shape_list):
                    raise ValueError(about._errors.cstring(
                        "ERROR: The global_shape must be the same on all " +
                        "nodes!"))
            return_dict['global_shape'] = global_shape

        # Case 2: local-type slicer
        elif distribution_strategy in ['freeform']:
            if isinstance(global_data, distributed_data_object):
                local_shape = global_data.local_shape
            elif local_data is not None and np.isscalar(local_data) == False:
                local_shape = local_data.shape
            elif local_shape is not None:
                local_shape = tuple(local_shape)
            else:
                raise ValueError(about._errors.cstring(
                    "ERROR: Neither non-0-dimensional local_data nor " +
                    "local_shape nor global d2o supplied!"))
            if local_shape == ():
                raise ValueError(about._errors.cstring(
                    "ERROR: local_shape == () is not a valid shape!"))

            if expensive_checks:
                local_shape_list = comm.allgather(local_shape[1:])
                cleared_set = set(local_shape_list)
                cleared_set.discard(())
                if len(cleared_set) > 1:
                    raise ValueError(about._errors.cstring(
                        "ERROR: All but the first entry of local_shape " +
                        "must be the same on all nodes!"))
            return_dict['local_shape'] = local_shape

        # Add the name of the distributor if needed
        if distribution_strategy in ['equal', 'fftw', 'freeform']:
            return_dict['name'] = distribution_strategy

        # close the file-handle
        if dset is not None:
            f.close()

        return return_dict

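    # Return-value sketch for parse_kwargs (assuming 'equal' is registered as
    # a global-type strategy in STRATEGIES and `comm` is a valid MPI
    # communicator; `factory` stands for a _distributor_factory instance):
    #
    #     parsed = factory.parse_kwargs(distribution_strategy='equal',
    #                                   comm=comm,
    #                                   global_shape=(8, 8),
    #                                   dtype=np.float64)
    #     # parsed == {'comm': comm, 'dtype': np.dtype('float64'),
    #     #            'name': 'equal', 'global_shape': (8, 8)}
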
    def hash_arguments(self, distribution_strategy, **kwargs):
        kwargs = kwargs.copy()

        comm = kwargs['comm']
        kwargs['comm'] = id(comm)

        if 'global_shape' in kwargs:
            kwargs['global_shape'] = kwargs['global_shape']
        if 'local_shape' in kwargs:
            local_shape = kwargs['local_shape']
            local_shape_list = comm.allgather(local_shape)
            kwargs['local_shape'] = tuple(local_shape_list)

        kwargs['dtype'] = self.dictionize_np(kwargs['dtype'])
        kwargs['distribution_strategy'] = distribution_strategy

        return frozenset(kwargs.items())

    def dictionize_np(self, x):
        dic = x.type.__dict__.items()
        if x is np.float:
            dic[24] = 0
            dic[29] = 0
            dic[37] = 0
        return frozenset(dic)

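    # hash_arguments builds the frozenset that get_distributor uses as the
    # cache key for distributor_store; dictionize_np merely converts the dtype
    # into a frozenset of its type's attributes so it can take part in that
    # key.
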
    def get_distributor(self, distribution_strategy, comm, **kwargs):
        # check if the distribution strategy is known
        if distribution_strategy not in STRATEGIES['all']:
            raise ValueError(about._errors.cstring(
                "ERROR: Unknown distribution strategy supplied."))

        # parse the kwargs
        parsed_kwargs = self.parse_kwargs(
            distribution_strategy=distribution_strategy,
            comm=comm,
            **kwargs)

        hashed_kwargs = self.hash_arguments(distribution_strategy,
                                            **parsed_kwargs)
        # check if the distributor has already been produced in the past
        if hashed_kwargs in self.distributor_store:
            return self.distributor_store[hashed_kwargs]
        else:
            # produce new distributor
            if distribution_strategy == 'not':
                produced_distributor = _not_distributor(**parsed_kwargs)

            elif distribution_strategy == 'equal':
                produced_distributor = _slicing_distributor(
                    slicer=_equal_slicer,
                    **parsed_kwargs)

            elif distribution_strategy == 'fftw':
                produced_distributor = _slicing_distributor(
                    slicer=_fftw_slicer,
                    **parsed_kwargs)

            elif distribution_strategy == 'freeform':
                produced_distributor = _slicing_distributor(
                    slicer=_freeform_slicer,
                    **parsed_kwargs)

            self.distributor_store[hashed_kwargs] = produced_distributor
            return self.distributor_store[hashed_kwargs]


distributor_factory = _distributor_factory()
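
# Usage sketch (assuming gc['mpi_module'] resolves to mpi4py's MPI and that
# 'equal' is an available strategy):
#
#     from mpi4py import MPI
#     distributor = distributor_factory.get_distributor(
#                       distribution_strategy='equal',
#                       comm=MPI.COMM_WORLD,
#                       global_shape=(16, 16),
#                       dtype=np.float64)
#
# A second call with the same arguments returns the cached instance from
# distributor_store instead of building a new one.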


def _infer_key_type(key):
    if key is None:
        return (None, None)
    found_boolean = False
    # Check which case we got:
    if isinstance(key, tuple) or isinstance(key, slice) or np.isscalar(key):
        # Check if there is something different in the array than
        # scalars and slices
        if isinstance(key, slice) or np.isscalar(key):
            key = [key]

        scalarQ = np.array(map(np.isscalar, key))
        sliceQ = np.array(map(lambda z: isinstance(z, slice), key))
        if np.all(scalarQ + sliceQ):
            found = 'slicetuple'
        else:
            found = 'indexinglist'
    elif isinstance(key, np.ndarray):
        found = 'ndarray'
        found_boolean = (key.dtype.type == np.bool_)
    elif isinstance(key, distributed_data_object):
        found = 'd2o'
        found_boolean = (key.dtype == np.bool_)
    elif isinstance(key, list):
        found = 'indexinglist'
    else:
        raise ValueError(about._errors.cstring("ERROR: Unknown keytype!"))
    return (found, found_boolean)
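
# Illustrative classification results (a sketch based on the branches above):
#     (slice(1, 5), 3)                  -> ('slicetuple', False)
#     np.array([True, False])           -> ('ndarray', True)
#     [np.array([0, 2, 3])]             -> ('indexinglist', False)
#     a boolean distributed_data_object -> ('d2o', True)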


class distributor(object):

    #def inject(self, data, to_slices, data_update, from_slices,
    #           **kwargs):
    #    # check if to_key and from_key are completely built of slices
    #    if not np.all(
    #            np.vectorize(lambda x: isinstance(x, slice))(to_slices)):
    #        raise ValueError(about._errors.cstring(
    #            "ERROR: The to_slices argument must be a list or " +
    #            "tuple of slices!")
    #        )

    #    if not np.all(
    #            np.vectorize(lambda x: isinstance(x, slice))(from_slices)):
    #        raise ValueError(about._errors.cstring(
    #            "ERROR: The from_slices argument must be a list or " +
    #            "tuple of slices!")
    #        )

    #    to_slices = tuple(to_slices)
    #    from_slices = tuple(from_slices)
    #    self.disperse_data(data=data,
    #                       to_key=to_slices,
    #                       data_update=data_update,
    #                       from_key=from_slices,
    #                       **kwargs)

    def disperse_data(self, data, to_key, data_update, from_key=None,
                      local_keys=False, copy=True, **kwargs):
        # Check which keys we got:
        (to_found, to_found_boolean) = _infer_key_type(to_key)
        (from_found, from_found_boolean) = _infer_key_type(from_key)

        comm = self.comm
        if local_keys is False:
            return self._disperse_data_primitive(
                                         data=data,
                                         to_key=to_key,
                                         data_update=data_update,
                                         from_key=from_key,
                                         copy=copy,
                                         to_found=to_found,
                                         to_found_boolean=to_found_boolean,
                                         from_found=from_found,
                                         from_found_boolean=from_found_boolean,
                                         **kwargs)

        else:
            # assert that all to_keys are of the same type
            to_found_list = comm.allgather(to_found)
            assert(all(x == to_found_list[0] for x in to_found_list))
            to_found_boolean_list = comm.allgather(to_found_boolean)
            assert(all(x == to_found_boolean_list[0] for x in
                       to_found_boolean_list))
            from_found_list = comm.allgather(from_found)
            assert(all(x == from_found_list[0] for x in from_found_list))
            from_found_boolean_list = comm.allgather(from_found_boolean)
            assert(all(x == from_found_boolean_list[0] for
                       x in from_found_boolean_list))

            # gather the local to_keys into a global to_key_list
            # Case 1: the to_keys are not distributed_data_objects
            # -> allgather does the job
            if to_found != 'd2o':
                to_key_list = comm.allgather(to_key)
            # Case 2: if the to_keys are distributed_data_objects, gather
            # the index of the array and build the to_key_list with help
            # from the librarian
            else:
                to_index_list = comm.allgather(to_key.index)
                to_key_list = map(lambda z: d2o_librarian[z], to_index_list)

            # gather the local from_keys. It is the same procedure as above
            if from_found != 'd2o':
                from_key_list = comm.allgather(from_key)
            else:
                from_index_list = comm.allgather(from_key.index)
                from_key_list = map(lambda z: d2o_librarian[z],
                                    from_index_list)

            local_data_update_is_scalar = np.isscalar(data_update)
            local_scalar_list = comm.allgather(local_data_update_is_scalar)
            for i in xrange(len(to_key_list)):
                if np.all(np.array(local_scalar_list) == True):
                    scalar_list = comm.allgather(data_update)
                    temp_data_update = scalar_list[i]
                elif isinstance(data_update, distributed_data_object):
                    data_update_index_list = comm.allgather(data_update.index)
                    data_update_list = map(lambda z: d2o_librarian[z],
                                           data_update_index_list)
                    temp_data_update = data_update_list[i]
                else:
                    # build a temporary freeform d2o which only contains data
                    # from node i
                    if comm.rank == i:
                        temp_shape = np.shape(data_update)
                        try:
                            temp_dtype = np.dtype(data_update.dtype)
                        except(TypeError):
                            temp_dtype = np.array(data_update).dtype
                    else:
                        temp_shape = None
                        temp_dtype = None
                    temp_shape = comm.bcast(temp_shape, root=i)
                    temp_dtype = comm.bcast(temp_dtype, root=i)

                    if comm.rank != i:
                        temp_shape = list(temp_shape)
                        temp_shape[0] = 0
                        temp_shape = tuple(temp_shape)
                        temp_data = np.empty(temp_shape, dtype=temp_dtype)
                    else:
                        temp_data = data_update
                    temp_data_update = distributed_data_object(
                                        local_data=temp_data,
                                        distribution_strategy='freeform',
                                        copy=False,
                                        comm=self.comm)
                # disperse the data one after another
                self._disperse_data_primitive(
                                      data=data,
                                      to_key=to_key_list[i],
                                      data_update=temp_data_update,
                                      from_key=from_key_list[i],
                                      copy=copy,
                                      to_found=to_found,
                                      to_found_boolean=to_found_boolean,
                                      from_found=from_found,
                                      from_found_boolean=from_found_boolean,
                                      **kwargs)
                i += 1


class _slicing_distributor(distributor):
    def __init__(self, slicer, name, dtype, comm, **remaining_parsed_kwargs):

        self.comm = comm
        self.distribution_strategy = name
        self.dtype = np.dtype(dtype)

        self._my_dtype_converter = dtype_converter

        if not self._my_dtype_converter.known_np_Q(self.dtype):
            raise TypeError(about._errors.cstring(
                "ERROR: The datatype " + str(self.dtype.__repr__()) +
                " is not known to mpi4py."))

        self.mpi_dtype = self._my_dtype_converter.to_mpi(self.dtype)

        self.slicer = slicer
        self._local_size = self.slicer(comm=comm, **remaining_parsed_kwargs)
        self.local_start = self._local_size[0]
        self.local_end = self._local_size[1]
        self.global_shape = self._local_size[2]

        self.local_length = self.local_end - self.local_start
        self.local_shape = (self.local_length,) + tuple(self.global_shape[1:])
        self.local_dim = np.product(self.local_shape)
        self.local_dim_list = np.empty(comm.size, dtype=np.int)
        comm.Allgather([np.array(self.local_dim, dtype=np.int), MPI.INT],
                       [self.local_dim_list, MPI.INT])
        self.local_dim_offset = np.sum(self.local_dim_list[0:comm.rank])

        self.local_slice = np.array([self.local_start, self.local_end,
                                     self.local_length, self.local_dim,
                                     self.local_dim_offset],
                                    dtype=np.int)
        # collect all local_slices
        self.all_local_slices = np.empty((comm.size, 5), dtype=np.int)
        comm.Allgather([np.array((self.local_slice,), dtype=np.int), MPI.INT],
                       [self.all_local_slices, MPI.INT])
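        # Each row of all_local_slices describes one rank:
        #     [local_start, local_end, local_length, local_dim,
        #      local_dim_offset]
        # For example (sketch): a rank owning global rows 2..3 of a (6, 4)
        # array stores [2, 4, 2, 8, 8] if the ranks below it hold 8 elements
        # in total.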

    def initialize_data(self, global_data, local_data, alias, path, hermitian,
                        copy, **kwargs):
        if 'h5py' in gdi and alias is not None:
            local_data = self.load_data(alias=alias, path=path)
            return (local_data, hermitian)

        if self.distribution_strategy in ['equal', 'fftw']:
            if np.isscalar(global_data):
                local_data = np.empty(self.local_shape, dtype=self.dtype)
                local_data.fill(global_data)
                hermitian = True
            else:
                local_data = self.distribute_data(data=global_data,
                                                  copy=copy)
        elif self.distribution_strategy in ['freeform']:
            if isinstance(global_data, distributed_data_object):
                local_data = global_data.get_local_data()
            elif np.isscalar(local_data):
                temp_local_data = np.empty(self.local_shape,
                                           dtype=self.dtype)
                temp_local_data.fill(local_data)
                local_data = temp_local_data
                hermitian = True
            elif local_data is None:
                local_data = np.empty(self.local_shape, dtype=self.dtype)
            elif isinstance(local_data, np.ndarray):
                local_data = local_data.astype(
                               self.dtype, copy=copy).reshape(self.local_shape)
            else:
                local_data = np.array(local_data).astype(
                    self.dtype, copy=copy).reshape(self.local_shape)
        else:
            raise TypeError(about._errors.cstring(
                "ERROR: Unknown distribution strategy"))
        return (local_data, hermitian)

    def globalize_flat_index(self, index):
        return int(index) + self.local_dim_offset
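    # Sketch: with local_dim_offset == 6 (the lower ranks hold six elements in
    # total), globalize_flat_index(1) returns 7.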

    def globalize_index(self, index):
        index = np.array(index, dtype=np.int).flatten()
        if index.shape != (len(self.global_shape),):
            raise TypeError(about._errors.cstring("ERROR: Length\
                of index tuple does not match the array's shape!"))
        globalized_index = index
        globalized_index[0] = index[0] + self.local_start
        # ensure that the globalized index list is within the bounds
        global_index_memory = globalized_index
        globalized_index = np.clip(globalized_index,
                                   -np.array(self.global_shape),
                                   np.array(self.global_shape) - 1)
        if np.any(global_index_memory != globalized_index):
            about.warnings.cprint("WARNING: Indices were clipped!")
        globalized_index = tuple(globalized_index)
        return globalized_index

    def _allgather(self, thing, comm=None):
        if comm is None:
            comm = self.comm
        gathered_things = comm.allgather(thing)
        return gathered_things

    def _Allreduce_sum(self, sendbuf, recvbuf):
        send_dtype = self._my_dtype_converter.to_mpi(sendbuf.dtype)
        recv_dtype = self._my_dtype_converter.to_mpi(recvbuf.dtype)
        self.comm.Allreduce([sendbuf, send_dtype],
                            [recvbuf, recv_dtype],
                            op=MPI.SUM)
        return recvbuf

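    # Sketch of _Allreduce_sum: every rank passes a local numpy buffer as
    # sendbuf and receives the element-wise sum over all ranks in recvbuf;
    # e.g. with [1, 2] on one rank and [3, 4] on the other, both end up
    # holding [4, 6].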
    def distribute_data(self, data=None, alias=None,
                        path=None, copy=True, **kwargs):
        '''
        distribute_data checks
        - whether the data is located on all nodes or only on node 0
        - that the shape of 'data' matches the global_shape
        '''

        comm = self.comm

        if 'h5py' in gdi and alias is not None:
            data = self.load_data(alias=alias, path=path)

        local_data_available_Q = (data is not None)
        data_available_Q = np.array(comm.allgather(local_data_available_Q))

        if np.all(data_available_Q == False):
            return np.empty(self.local_shape, dtype=self.dtype, order='C')
        # if all nodes got data, we assume that it is the right data and
        # store it individually.
        elif np.all(data_available_Q == True):
            if isinstance(data, distributed_data_object):
                temp_d2o = data.get_data((slice(self.local_start,
                                                self.local_end),),
                                         local_keys=True,
                                         copy=copy)
                return temp_d2o.get_local_data(copy=False).astype(self.dtype,
                                                                  copy=False)
            else:
                return data[self.local_start:self.local_end].astype(
                    self.dtype,
                    copy=copy)
        else:
            raise ValueError(
                "ERROR: distribute_data must get data on all nodes!")

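    # Behaviour sketch for distribute_data above: if every rank passes the
    # full global array, each rank keeps only its slice
    # data[local_start:local_end]; if no rank passes data, an uninitialised
    # array of local_shape is returned instead.
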
    def _disperse_data_primitive(self, data, to_key, data_update, from_key,
                                 copy, to_found, to_found_boolean, from_found,
                                 from_found_boolean, **kwargs):
        if np.isscalar(data_update):
            from_key = None

        # Case 1: to_key is a slice-tuple. Hence, the basic indexing/slicing
        # machinery will be used
        if to_found == 'slicetuple':
            if from_found == 'slicetuple':
                return self.disperse_data_to_slices(data=data,
                                                    to_slices=to_key,
                                                    data_update=data_update,
                                                    from_slices=from_key,
                                                    copy=copy,
                                                    **kwargs)
            else:
                if from_key is not None:
                    about.infos.cprint(
                        "INFO: Advanced injection is not available for this " +
                        "combination of to_key and from_key.")
                    prepared_data_update = data_update[from_key]
                else:
                    prepared_data_update = data_update

                return self.disperse_data_to_slices(
                                            data=data,
                                            to_slices=to_key,
                                            data_update=prepared_data_update,
                                            copy=copy,
                                            **kwargs)

        # Case 2: key is an array
        elif (to_found == 'ndarray' or to_found == 'd2o'):
            # Case 2.1: The array is boolean.
            if to_found_boolean:
                if from_key is not None:
                    about.infos.cprint(
                        "INFO: Advanced injection is not available for this " +
                        "combination of to_key and from_key.")
                    prepared_data_update = data_update[from_key]
                else:
                    prepared_data_update = data_update
                return self.disperse_data_to_bool(
                                              data=data,
                                              to_boolean_key=to_key,
                                              data_update=prepared_data_update,
                                              copy=copy,
                                              **kwargs)
            # Case 2.2: The array is not boolean. Only 1-dimensional
            # advanced slicing is supported.
            else:
                if len(to_key.shape) != 1:
                    raise ValueError(about._errors.cstring(
                        "WARNING: Only one-dimensional advanced indexing " +
                        "is supported"))
                # Make a recursive call in order to trigger the 'list'-section
                return self.disperse_data(data=data, to_key=[to_key],
                                          data_update=data_update,
                                          from_key=from_key, copy=copy,
                                          **kwargs)

        # Case 3 : to_key is a list. This list is interpreted as
        # one-dimensional advanced indexing list.
        elif to_found == 'indexinglist':
            if from_key is not None:
                about.infos.cprint(
                    "INFO: Advanced injection is not available for this " +
                    "combination of to_key and from_key.")
                prepared_data_update = data_update[from_key]
            else:
                prepared_data_update = data_update
            return self.disperse_data_to_list(data=data,
                                              to_list_key=to_key,
                                              data_update=prepared_data_update,
                                              copy=copy,
                                              **kwargs)

    def disperse_data_to_list(self, data, to_list_key, data_update,
                              copy=True, **kwargs):

        if to_list_key == []:
            return data

        local_to_list_key = self._advanced_index_decycler(to_list_key)
        return self._disperse_data_to_list_and_bool_helper(
            data=data,
            local_to_key=local_to_list_key,
            data_update=data_update,
            copy=copy,
            **kwargs)

    def disperse_data_to_bool(self, data, to_boolean_key, data_update,
                              copy=True, **kwargs):
        # Extract the part of the to_boolean_key which corresponds to the
        # local data
        local_to_boolean_key = self.extract_local_data(to_boolean_key)
        return self._disperse_data_to_list_and_bool_helper(
            data=data,
            local_to_key=local_to_boolean_key,
            data_update=data_update,
            copy=copy,
            **kwargs)

    def _disperse_data_to_list_and_bool_helper(self, data, local_to_key,
                                               data_update, copy, **kwargs):
        comm = self.comm
        rank = comm.rank
        # Infer the length and offset of the locally affected data
        locally_affected_data = data[local_to_key]
        data_length = np.shape(locally_affected_data)[0]
        data_length_list = comm.allgather(data_length)
        data_length_offset_list = np.append([0],
                                            np.cumsum(data_length_list)[:-1])

        # Update the local data object with its very own portion
        o = data_length_offset_list
        l = data_length

        if isinstance(data_update, distributed_data_object):
            local_data_update = data_update.get_data(
                                          slice(o[rank], o[rank] + l),
                                          local_keys=True
                                          ).get_local_data(copy=False)
            data[local_to_key] = local_data_update.astype(self.dtype,
                                                          copy=False)
        elif np.isscalar(data_update):
            data[local_to_key] = data_update
        else:
            data[local_to_key] = np.array(data_update[o[rank]:o[rank] + l],
                                          copy=copy).astype(self.dtype,
                                                            copy=False)
        return data

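    # Offset sketch for the helper above: if the gathered per-rank lengths in
    # data_length_list are [2, 3, 1], the offsets become [0, 2, 5], so rank r
    # writes data_update[o[r]:o[r] + l] into its own local selection.
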
    def disperse_data_to_slices(self, data, to_slices,
                                data_update, from_slices=None, copy=True):
        comm = self.comm
        (to_slices, sliceified) = self._sliceify(to_slices)

        # parse the to_slices object
        localized_to_start, localized_to_stop = self._backshift_and_decycle(
            to_slices[0], self.local_start, self.local_end,
            self.global_shape[0])
        local_to_slice = (slice(localized_to_start, localized_to_stop,
                                to_slices[0].step),) + to_slices[1:]
        local_to_slice_shape = data[local_to_slice].shape

        to_step = to_slices[0].step
        if to_step is None:
            to_step = 1
        elif to_step == 0:
            raise ValueError(about._errors.cstring(
                "ERROR: to_step size == 0!"))

        # Compute the offset of the data the individual node will take.
        # The offset is free of stepsizes. It is the offset in terms of
        # the purely transported data. If to_step < 0, the offset will
        # be calculated in reverse order
        order = np.sign(to_step)

        local_affected_data_length = local_to_slice_shape[0]
        local_affected_data_length_list = np.empty(comm.size, dtype=np.int)
        comm.Allgather(
            [np.array(local_affected_data_length, dtype=np.int), MPI.INT],
            [local_affected_data_length_list, MPI.INT])
        local_affected_data_length_offset_list = np.append([0],
                                                           np.cumsum(
            local_affected_data_length_list[::order])[:-1])[::order]

        if np.isscalar(data_update):
            data[local_to_slice] = data_update
        else:
            # construct the locally adapted from_slice object
            r = comm.rank
            o = local_affected_data_length_offset_list
            l = local_affected_data_length

            data_update = self._enfold(data_update, sliceified)

            # parse the from_slices object
            if from_slices is None:
                from_slices = (slice(None, None, None),)
            (from_slices_start, from_slices_stop) = \
                self._backshift_and_decycle(
                                            slice_object=from_slices[0],
                                            shifted_start=0,
                                            shifted_stop=data_update.shape[0],
                                            global_length=data_update.shape[0])
            if from_slices_start is None:
                raise ValueError(about._errors.cstring(
                    "ERROR: _backshift_and_decycle should never return " +
                    "None for local_start!"))

            # parse the step sizes
            from_step = from_slices[0].step
            if from_step is None:
                from_step = 1
            elif from_step == 0:
                raise ValueError(about._errors.cstring(
                    "ERROR: from_step size == 0!"))

            localized_from_start = from_slices_start + from_step * o[r]
            localized_from_stop = localized_from_start + from_step * l
            if localized_from_stop < 0:
                localized_from_stop = None

            localized_from_slice = (slice(localized_from_start,
                                          localized_from_stop,
                                          from_step),)

            update_slice = localized_from_slice + from_slices[1:]

            if isinstance(data_update, distributed_data_object):
                selected_update = data_update.get_data(
                                 key=update_slice,
                                 local_keys=True)
                local_data_update = selected_update.get_local_data(copy=False)
                local_data_update = local_data_update.astype(self.dtype,
                                                             copy=False)
                if np.prod(np.shape(local_data_update)) != 0:
                    data[local_to_slice] = local_data_update
            # elif np.isscalar(data_update):
            #    data[local_to_slice] = data_update
            else:
                local_data_update = np.array(data_update)[update_slice]
                if np.prod(np.shape(local_data_update)) != 0:
                    data[local_to_slice] = np.array(
                                                local_data_update,
                                                copy=copy).astype(self.dtype,
                                                                  copy=False)

    def collect_data(self, data, key, local_keys=False, copy=True, **kwargs):
        # collect_data supports three types of keys
        # Case 1: key is a slicing/index tuple
        # Case 2: key is a boolean-array of the same shape as self
        # Case 3: key is a list of shape (n,), where n is
        #         0<n<len(self.shape). The entries of the list must be a
        #         scalar/list/tuple/ndarray. If not scalar the length must be
        #         the same for all of the lists. This is essentially
        #         numpy advanced indexing in one dimension, only.

        # Check which case we got:
        (found, found_boolean) = _infer_key_type(key)
        comm = self.comm
        if local_keys is False:
            return self._collect_data_primitive(data, key, found,
                                                found_boolean, copy=copy,
                                                **kwargs)
        else:
            # assert that all keys are of the same type
            found_list = comm.allgather(found)
            assert(all(x == found_list[0] for x in found_list))
            found_boolean_list = comm.allgather(found_boolean)
            assert(all(x == found_boolean_list[0] for x in found_boolean_list))

            # gather the local_keys into a global key_list
            # Case 1: the keys are not distributed_data_objects
            # -> allgather does the job
            if found != 'd2o':
                key_list = comm.allgather(key)
            # Case 2: if the keys are distributed_data_objects, gather
            # the index of the array and build the key_list with help
            # from the librarian
            else:
                index_list = comm.allgather(key.index)
                key_list = map(lambda z: d2o_librarian[z], index_list)
            i = 0
            for temp_key in key_list:
                # build the locally fed d2o
                temp_d2o = self._collect_data_primitive(data, temp_key, found,
                                                        found_boolean,
                                                        copy=copy, **kwargs)
                # collect the data stored in the d2o to the individual target
                # rank
                temp_data = temp_d2o.get_full_data(target_rank=i)
                if comm.rank == i:
                    individual_data = temp_data
                i += 1
            return_d2o = distributed_data_object(
                            local_data=individual_data,
                            distribution_strategy='freeform',
                            copy=False,
                            comm=self.comm)
            return return_d2o

    def _collect_data_primitive(self, data, key, found, found_boolean,
                                copy=True, **kwargs):

        # Case 1: key is a slice-tuple. Hence, the basic indexing/slicing
        # machinery will be used
        if found == 'slicetuple':
            return self.collect_data_from_slices(data=data,
                                                 slice_objects=key,
                                                 copy=copy,
                                                 **kwargs)
        # Case 2: key is an array
        elif (found == 'ndarray' or found == 'd2o'):