distributor_factory.py 78.4 KB
Newer Older
ultimanet's avatar
ultimanet committed
1
# -*- coding: utf-8 -*-
2

ultimanet's avatar
ultimanet committed
3
import numpy as np
4

theos's avatar
theos committed
5
6
7
from nifty.keepers import about,\
                          global_configuration as gc,\
                          global_dependency_injector as gdi
8

theos's avatar
theos committed
9
from distributed_data_object import distributed_data_object
10

theos's avatar
theos committed
11
12
13
14
from d2o_iter import d2o_slicing_iter,\
                     d2o_not_iter
from d2o_librarian import d2o_librarian
from dtype_converter import dtype_converter
15

theos's avatar
theos committed
16
from strategies import STRATEGIES
17

theos's avatar
theos committed
18
19
20
# Resolve backend modules through the global dependency injector.
# The MPI implementation is looked up under the name stored in the global
# configuration entry 'mpi_module'.
MPI = gdi[gc['mpi_module']]
# NOTE(review): gdi.get presumably yields None when the module is not
# available -- confirm. Code in this file guards h5py usage with
# `'h5py' in gdi` before touching it.
h5py = gdi.get('h5py')
pyfftw = gdi.get('pyfftw')
ultimanet's avatar
ultimanet committed
21

22

23
class _distributor_factory(object):
24

25
26
    def __init__(self):
        self.distributor_store = {}
27

28
    def parse_kwargs(self, distribution_strategy, comm,
29
30
31
                     global_data=None, global_shape=None,
                     local_data=None, local_shape=None,
                     alias=None, path=None,
32
33
34
35
36
37
38
39
40
41
42
43
                     dtype=None, skip_parsing=False, **kwargs):

        if skip_parsing:
            return_dict = {'comm': comm,
                           'dtype': dtype,
                           'name': distribution_strategy
                           }
            if distribution_strategy in STRATEGIES['global']:
                return_dict['global_shape'] = global_shape
            elif distribution_strategy in STRATEGIES['local']:
                return_dict['local_shape'] = local_shape
            return return_dict
Ultima's avatar
Ultima committed
44

45
        return_dict = {}
46
47
48
49
50

        expensive_checks = gc['d2o_init_checks']

        # Parse the MPI communicator
        if comm is None:
Ultima's avatar
Ultima committed
51
            raise ValueError(about._errors.cstring(
52
53
54
55
56
57
58
59
60
61
62
                "ERROR: The distributor needs MPI-communicator object comm!"))
        else:
            return_dict['comm'] = comm

        if expensive_checks:
            # Check that all nodes got the same distribution_strategy
            strat_list = comm.allgather(distribution_strategy)
            if all(x == strat_list[0] for x in strat_list) == False:
                raise ValueError(about._errors.cstring(
                    "ERROR: The distribution-strategy must be the same on " +
                    "all nodes!"))
63

64
        # Check for an hdf5 file and open it if given
65
        if 'h5py' in gdi and alias is not None:
66
67
68
            # set file path
            file_path = path if (path is not None) else alias
            # open hdf5 file
69
            if h5py.get_config().mpi and gc['mpi_module'] == 'MPI':
Ultima's avatar
Ultima committed
70
71
                f = h5py.File(file_path, 'r', driver='mpio', comm=comm)
            else:
72
73
74
                f = h5py.File(file_path, 'r')
            # open alias in file
            dset = f[alias]
Ultima's avatar
Ultima committed
75
76
77
        else:
            dset = None

78
        # Parse the datatype
Ultima's avatar
Ultima committed
79
        if distribution_strategy in ['not', 'equal', 'fftw'] and \
80
                (dset is not None):
81
            dtype = dset.dtype
82
83

        elif distribution_strategy in ['not', 'equal', 'fftw']:
Ultima's avatar
Ultima committed
84
            if dtype is None:
85
                if global_data is None:
Ultima's avatar
Ultima committed
86
87
                    dtype = np.dtype('float64')
                    about.infos.cprint('INFO: dtype set was set to default.')
88
                else:
Ultima's avatar
Ultima committed
89
                    try:
90
                        dtype = global_data.dtype
Ultima's avatar
Ultima committed
91
                    except(AttributeError):
92
93
94
                        dtype = np.array(global_data).dtype
            else:
                dtype = np.dtype(dtype)
95

96
        elif distribution_strategy in STRATEGIES['local']:
97
            if dtype is None:
Ultima's avatar
Ultima committed
98
99
100
                if isinstance(global_data, distributed_data_object):
                    dtype = global_data.dtype
                elif local_data is not None:
Ultima's avatar
Ultima committed
101
                    try:
102
                        dtype = local_data.dtype
Ultima's avatar
Ultima committed
103
                    except(AttributeError):
104
                        dtype = np.array(local_data).dtype
Ultima's avatar
Ultima committed
105
106
107
                else:
                    dtype = np.dtype('float64')
                    about.infos.cprint('INFO: dtype set was set to default.')
108

Ultima's avatar
Ultima committed
109
            else:
110
                dtype = np.dtype(dtype)
111
112
113
114
115
        if expensive_checks:
            dtype_list = comm.allgather(dtype)
            if all(x == dtype_list[0] for x in dtype_list) == False:
                raise ValueError(about._errors.cstring(
                    "ERROR: The given dtype must be the same on all nodes!"))
Ultima's avatar
Ultima committed
116
        return_dict['dtype'] = dtype
117
118
119

        # Parse the shape
        # Case 1: global-type slicer
120
        if distribution_strategy in STRATEGIES['global']:
Ultima's avatar
Ultima committed
121
122
123
124
125
126
127
128
129
            if dset is not None:
                global_shape = dset.shape
            elif global_data is not None and np.isscalar(global_data) == False:
                global_shape = global_data.shape
            elif global_shape is not None:
                global_shape = tuple(global_shape)
            else:
                raise ValueError(about._errors.cstring(
                    "ERROR: Neither non-0-dimensional global_data nor " +
130
                    "global_shape nor hdf5 file supplied!"))
Ultima's avatar
Ultima committed
131
132
            if global_shape == ():
                raise ValueError(about._errors.cstring(
Ultima's avatar
Ultima committed
133
                    "ERROR: global_shape == () is not a valid shape!"))
134

135
136
137
138
139
140
141
            if expensive_checks:
                global_shape_list = comm.allgather(global_shape)
                if not all(x == global_shape_list[0]
                           for x in global_shape_list):
                    raise ValueError(about._errors.cstring(
                        "ERROR: The global_shape must be the same on all " +
                        "nodes!"))
Ultima's avatar
Ultima committed
142
143
            return_dict['global_shape'] = global_shape

144
        # Case 2: local-type slicer
Ultima's avatar
Ultima committed
145
146
147
148
        elif distribution_strategy in ['freeform']:
            if isinstance(global_data, distributed_data_object):
                local_shape = global_data.local_shape
            elif local_data is not None and np.isscalar(local_data) == False:
Ultima's avatar
Ultima committed
149
150
151
152
153
154
                local_shape = local_data.shape
            elif local_shape is not None:
                local_shape = tuple(local_shape)
            else:
                raise ValueError(about._errors.cstring(
                    "ERROR: Neither non-0-dimensional local_data nor " +
155
                    "local_shape nor global d2o supplied!"))
Ultima's avatar
Ultima committed
156
157
            if local_shape == ():
                raise ValueError(about._errors.cstring(
158
159
                    "ERROR: local_shape == () is not a valid shape!"))

160
161
162
163
164
165
166
167
            if expensive_checks:
                local_shape_list = comm.allgather(local_shape[1:])
                cleared_set = set(local_shape_list)
                cleared_set.discard(())
                if len(cleared_set) > 1:
                    raise ValueError(about._errors.cstring(
                        "ERROR: All but the first entry of local_shape " +
                        "must be the same on all nodes!"))
Ultima's avatar
Ultima committed
168
169
            return_dict['local_shape'] = local_shape

170
        # Add the name of the distributor if needed
171
172
        if distribution_strategy in ['equal', 'fftw', 'freeform']:
            return_dict['name'] = distribution_strategy
173
174

        # close the file-handle
Ultima's avatar
Ultima committed
175
176
177
        if dset is not None:
            f.close()

178
        return return_dict
179

Ultima's avatar
Ultima committed
180
    def hash_arguments(self, distribution_strategy, **kwargs):
181
        kwargs = kwargs.copy()
182

183
184
        comm = kwargs['comm']
        kwargs['comm'] = id(comm)
185

186
        if 'global_shape' in kwargs:
Ultima's avatar
Ultima committed
187
            kwargs['global_shape'] = kwargs['global_shape']
188
        if 'local_shape' in kwargs:
189
190
191
            local_shape = kwargs['local_shape']
            local_shape_list = comm.allgather(local_shape)
            kwargs['local_shape'] = tuple(local_shape_list)
192

Ultima's avatar
Ultima committed
193
        kwargs['dtype'] = self.dictionize_np(kwargs['dtype'])
194
        kwargs['distribution_strategy'] = distribution_strategy
195

196
        return frozenset(kwargs.items())
ultimanet's avatar
ultimanet committed
197

198
    def dictionize_np(self, x):
199
        dic = x.type.__dict__.items()
200
        if x is np.float:
201
            dic[24] = 0
202
203
            dic[29] = 0
            dic[37] = 0
204
205
        return frozenset(dic)

206
    def get_distributor(self, distribution_strategy, comm, **kwargs):
207
        # check if the distribution strategy is known
208
        if distribution_strategy not in STRATEGIES['all']:
209
            raise ValueError(about._errors.cstring(
210
                "ERROR: Unknown distribution strategy supplied."))
211
212

        # parse the kwargs
Ultima's avatar
Ultima committed
213
        parsed_kwargs = self.parse_kwargs(
214
215
216
            distribution_strategy=distribution_strategy,
            comm=comm,
            **kwargs)
217

Ultima's avatar
Ultima committed
218
219
        hashed_kwargs = self.hash_arguments(distribution_strategy,
                                            **parsed_kwargs)
220
        # check if the distributors has already been produced in the past
221
        if hashed_kwargs in self.distributor_store:
Ultima's avatar
Ultima committed
222
            return self.distributor_store[hashed_kwargs]
223
        else:
224
            # produce new distributor
225
            if distribution_strategy == 'not':
Ultima's avatar
Ultima committed
226
                produced_distributor = _not_distributor(**parsed_kwargs)
227

228
229
            elif distribution_strategy == 'equal':
                produced_distributor = _slicing_distributor(
230
231
                    slicer=_equal_slicer,
                    **parsed_kwargs)
232

233
234
            elif distribution_strategy == 'fftw':
                produced_distributor = _slicing_distributor(
235
236
                    slicer=_fftw_slicer,
                    **parsed_kwargs)
Ultima's avatar
Ultima committed
237
238
            elif distribution_strategy == 'freeform':
                produced_distributor = _slicing_distributor(
239
240
                    slicer=_freeform_slicer,
                    **parsed_kwargs)
241
242

            self.distributor_store[hashed_kwargs] = produced_distributor
Ultima's avatar
Ultima committed
243
            return self.distributor_store[hashed_kwargs]
244
245


246
# Module-level factory instance; its distributor_store starts out empty.
distributor_factory = _distributor_factory()
Ultima's avatar
Ultima committed
247
248
249
250
251
252


def _infer_key_type(key):
    if key is None:
        return (None, None)
    found_boolean = False
253
    # Check which case we got:
254
    if isinstance(key, tuple) or isinstance(key, slice) or np.isscalar(key):
255
256
        # Check if there is something different in the array than
        # scalars and slices
Ultima's avatar
Ultima committed
257
258
        if isinstance(key, slice) or np.isscalar(key):
            key = [key]
259

Ultima's avatar
Ultima committed
260
261
262
        scalarQ = np.array(map(np.isscalar, key))
        sliceQ = np.array(map(lambda z: isinstance(z, slice), key))
        if np.all(scalarQ + sliceQ):
263
            found = 'slicetuple'
Ultima's avatar
Ultima committed
264
265
266
267
268
269
270
271
272
273
        else:
            found = 'indexinglist'
    elif isinstance(key, np.ndarray):
        found = 'ndarray'
        found_boolean = (key.dtype.type == np.bool_)
    elif isinstance(key, distributed_data_object):
        found = 'd2o'
        found_boolean = (key.dtype == np.bool_)
    elif isinstance(key, list):
        found = 'indexinglist'
274
275
    else:
        raise ValueError(about._errors.cstring("ERROR: Unknown keytype!"))
Ultima's avatar
Ultima committed
276
277
278
279
    return (found, found_boolean)


class distributor(object):
    """Base class providing `disperse_data`.

    The methods here read ``self.comm`` and call
    ``self._disperse_data_primitive``; neither is defined in this class, so
    both are expected to be supplied by the concrete distributor.
    """

    #def inject(self, data, to_slices, data_update, from_slices,
    #           **kwargs):
    #    # check if to_key and from_key are completely built of slices
    #    if not np.all(
    #            np.vectorize(lambda x: isinstance(x, slice))(to_slices)):
    #        raise ValueError(about._errors.cstring(
    #            "ERROR: The to_slices argument must be a list or " +
    #            "tuple of slices!")
    #        )

    #    if not np.all(
    #            np.vectorize(lambda x: isinstance(x, slice))(from_slices)):
    #        raise ValueError(about._errors.cstring(
    #            "ERROR: The from_slices argument must be a list or " +
    #            "tuple of slices!")
    #        )

    #    to_slices = tuple(to_slices)
    #    from_slices = tuple(from_slices)
    #    self.disperse_data(data=data,
    #                       to_key=to_slices,
    #                       data_update=data_update,
    #                       from_key=from_slices,
    #                       **kwargs)

    def disperse_data(self, data, to_key, data_update, from_key=None,
                      local_keys=False, copy=True, **kwargs):
        """Write `data_update` into `data` at the position(s) `to_key`.

        Parameters
        ----------
        data :
            Target passed on to ``_disperse_data_primitive``.
        to_key, from_key :
            Keys as understood by `_infer_key_type`.
        data_update :
            Scalar, distributed_data_object, or array-like update.
        local_keys : bool
            If False, the keys are handed to ``_disperse_data_primitive``
            once and its result is returned. Otherwise every node's keys
            and update are gathered and dispersed one node after another;
            nothing is returned in that case.
        copy : bool
            Forwarded to ``_disperse_data_primitive``.

        Raises
        ------
        AssertionError
            With `local_keys`, if the nodes supplied keys of differing type.
        """
        # Check which keys we got:
        (to_found, to_found_boolean) = _infer_key_type(to_key)
        (from_found, from_found_boolean) = _infer_key_type(from_key)

        comm = self.comm
        if local_keys is False:
            return self._disperse_data_primitive(
                                         data=data,
                                         to_key=to_key,
                                         data_update=data_update,
                                         from_key=from_key,
                                         copy=copy,
                                         to_found=to_found,
                                         to_found_boolean=to_found_boolean,
                                         from_found=from_found,
                                         from_found_boolean=from_found_boolean,
                                         **kwargs)

        # assert that all to_keys are from same type
        to_found_list = comm.allgather(to_found)
        assert(all(x == to_found_list[0] for x in to_found_list))
        to_found_boolean_list = comm.allgather(to_found_boolean)
        assert(all(x == to_found_boolean_list[0] for x in
                   to_found_boolean_list))
        from_found_list = comm.allgather(from_found)
        assert(all(x == from_found_list[0] for x in from_found_list))
        from_found_boolean_list = comm.allgather(from_found_boolean)
        assert(all(x == from_found_boolean_list[0] for
                   x in from_found_boolean_list))

        # gather the local keys into global key lists
        to_key_list = self._gather_keys(to_key, to_found)
        from_key_list = self._gather_keys(from_key, from_found)

        # The kind of data_update does not change between iterations, so
        # the collective gathers are done once, before the loop.
        local_scalar_list = comm.allgather(np.isscalar(data_update))
        if all(local_scalar_list):
            update_list = comm.allgather(data_update)
        elif isinstance(data_update, distributed_data_object):
            update_index_list = comm.allgather(data_update.index)
            update_list = [d2o_librarian[z] for z in update_index_list]
        else:
            update_list = None

        for i in range(len(to_key_list)):
            if update_list is not None:
                temp_data_update = update_list[i]
            else:
                temp_data_update = self._single_node_update(data_update, i)
            # disperse the data one after another
            self._disperse_data_primitive(
                                  data=data,
                                  to_key=to_key_list[i],
                                  data_update=temp_data_update,
                                  from_key=from_key_list[i],
                                  copy=copy,
                                  to_found=to_found,
                                  to_found_boolean=to_found_boolean,
                                  from_found=from_found,
                                  from_found_boolean=from_found_boolean,
                                  **kwargs)

    def _gather_keys(self, key, found):
        """Gather every node's local key into one list."""
        comm = self.comm
        # Case 1: the keys are not distributed_data_objects
        # -> allgather does the job
        if found != 'd2o':
            return comm.allgather(key)
        # Case 2: the keys are distributed_data_objects: gather their
        # indices and look the objects up in the librarian
        index_list = comm.allgather(key.index)
        return [d2o_librarian[z] for z in index_list]

    def _single_node_update(self, data_update, root):
        """Build a freeform d2o holding only node `root`'s `data_update`.

        Shape and dtype are broadcast from `root`; every other node
        contributes an empty array whose first dimension is 0.
        """
        comm = self.comm
        if comm.rank == root:
            temp_shape = np.shape(data_update)
            try:
                temp_dtype = np.dtype(data_update.dtype)
            except (AttributeError, TypeError):
                # data_update has no usable dtype attribute (e.g. a list)
                temp_dtype = np.array(data_update).dtype
        else:
            temp_shape = None
            temp_dtype = None
        temp_shape = comm.bcast(temp_shape, root=root)
        temp_dtype = comm.bcast(temp_dtype, root=root)

        if comm.rank != root:
            temp_shape = list(temp_shape)
            temp_shape[0] = 0
            temp_shape = tuple(temp_shape)
            temp_data = np.empty(temp_shape, dtype=temp_dtype)
        else:
            temp_data = data_update
        return distributed_data_object(
            local_data=temp_data,
            distribution_strategy='freeform',
            comm=self.comm)

410

Ultima's avatar
Ultima committed
411
class _slicing_distributor(distributor):
412

413
    def __init__(self, slicer, name, dtype, comm, **remaining_parsed_kwargs):
        """Set up the slicing information of this node.

        Parameters
        ----------
        slicer : callable
            Called as ``slicer(comm=comm, **remaining_parsed_kwargs)``; the
            entries 0, 1 and 2 of its result are used as local start, local
            end and global shape.
        name : str
            Stored as ``distribution_strategy``.
        dtype :
            Converted with ``np.dtype``; must be known to the dtype
            converter.
        comm : MPI communicator

        Raises
        ------
        TypeError
            If the dtype converter does not know `dtype`.
        """
        self.comm = comm
        self.distribution_strategy = name
        self.dtype = np.dtype(dtype)

        self._my_dtype_converter = dtype_converter

        if not self._my_dtype_converter.known_np_Q(self.dtype):
            raise TypeError(about._errors.cstring(
                "ERROR: The datatype " + str(self.dtype.__repr__()) +
                " is not known to mpi4py."))

        self.mpi_dtype = self._my_dtype_converter.to_mpi(self.dtype)

        self.slicer = slicer
        self._local_size = self.slicer(comm=comm, **remaining_parsed_kwargs)
        self.local_start = self._local_size[0]
        self.local_end = self._local_size[1]
        self.global_shape = self._local_size[2]

        self.local_length = self.local_end - self.local_start
        self.local_shape = (self.local_length,) + tuple(self.global_shape[1:])
        self.local_dim = np.prod(self.local_shape)
        # `int` is what the removed alias `np.int` referred to.
        # NOTE(review): the buffers use the platform integer while the MPI
        # datatype is MPI.INT -- confirm this pairing is intended.
        self.local_dim_list = np.empty(comm.size, dtype=int)
        comm.Allgather([np.array(self.local_dim, dtype=int), MPI.INT],
                       [self.local_dim_list, MPI.INT])
        # number of elements held by all nodes of lower rank
        self.local_dim_offset = np.sum(self.local_dim_list[0:comm.rank])

        self.local_slice = np.array([self.local_start, self.local_end,
                                     self.local_length, self.local_dim,
                                     self.local_dim_offset],
                                    dtype=int)
        # collect all local_slices
        self.all_local_slices = np.empty((comm.size, 5), dtype=int)
        comm.Allgather([np.array((self.local_slice,), dtype=int), MPI.INT],
                       [self.all_local_slices, MPI.INT])
450
451

    def initialize_data(self, global_data, local_data, alias, path, hermitian,
Ultima's avatar
Ultima committed
452
                        copy, **kwargs):
453
        if 'h5py' in gdi and alias is not None:
454
            local_data = self.load_data(alias=alias, path=path)
Ultima's avatar
Ultima committed
455
            return (local_data, hermitian)
456
457

        if self.distribution_strategy in ['equal', 'fftw']:
Ultima's avatar
Ultima committed
458
            if np.isscalar(global_data):
459
                local_data = np.empty(self.local_shape, dtype=self.dtype)
460
                local_data.fill(global_data)
Ultima's avatar
Ultima committed
461
462
                hermitian = True
            else:
463
464
                local_data = self.distribute_data(data=global_data,
                                                  copy=copy)
Ultima's avatar
Ultima committed
465
        elif self.distribution_strategy in ['freeform']:
Ultima's avatar
Ultima committed
466
467
468
            if isinstance(global_data, distributed_data_object):
                local_data = global_data.get_local_data()
            elif np.isscalar(local_data):
469
                temp_local_data = np.empty(self.local_shape,
470
                                           dtype=self.dtype)
471
                temp_local_data.fill(local_data)
472
                local_data = temp_local_data
Ultima's avatar
Ultima committed
473
474
                hermitian = True
            elif local_data is None:
475
                local_data = np.empty(self.local_shape, dtype=self.dtype)
476
477
478
            elif isinstance(local_data, np.ndarray):
                local_data = local_data.astype(
                               self.dtype, copy=copy).reshape(self.local_shape)
Ultima's avatar
Ultima committed
479
480
            else:
                local_data = np.array(local_data).astype(
481
                    self.dtype, copy=copy).reshape(self.local_shape)
Ultima's avatar
Ultima committed
482
483
        else:
            raise TypeError(about._errors.cstring(
484
                "ERROR: Unknown istribution strategy"))
485
486
        return (local_data, hermitian)

487
    def globalize_flat_index(self, index):
488
        return int(index) + self.local_dim_offset
489

490
491
492
    def globalize_index(self, index):
        index = np.array(index, dtype=np.int).flatten()
        if index.shape != (len(self.global_shape),):
Ultimanet's avatar
Ultimanet committed
493
            raise TypeError(about._errors.cstring("ERROR: Length\
494
                of index tuple does not match the array's shape!"))
495
496
        globalized_index = index
        globalized_index[0] = index[0] + self.local_start
497
        # ensure that the globalized index list is within the bounds
498
        global_index_memory = globalized_index
499
        globalized_index = np.clip(globalized_index,
500
                                   -np.array(self.global_shape),
501
                                   np.array(self.global_shape) - 1)
502
        if np.any(global_index_memory != globalized_index):
Ultimanet's avatar
Ultimanet committed
503
            about.warnings.cprint("WARNING: Indices were clipped!")
504
505
        globalized_index = tuple(globalized_index)
        return globalized_index
506

507
    def _allgather(self, thing, comm=None):
        """Return ``allgather(thing)`` on `comm`, or on ``self.comm`` if
        `comm` is None."""
        communicator = self.comm if comm is None else comm
        return communicator.allgather(thing)
512

Ultima's avatar
Ultima committed
513
514
515
516
517
518
519
520
    def _Allreduce_sum(self, sendbuf, recvbuf):
        """Sum `sendbuf` over ``self.comm`` into `recvbuf`; return `recvbuf`.

        The MPI datatypes are obtained from the buffers' numpy dtypes via
        the dtype converter.
        """
        to_mpi = self._my_dtype_converter.to_mpi
        send_spec = [sendbuf, to_mpi(sendbuf.dtype)]
        recv_spec = [recvbuf, to_mpi(recvbuf.dtype)]
        self.comm.Allreduce(send_spec, recv_spec, op=MPI.SUM)
        return recvbuf

521
    def distribute_data(self, data=None, alias=None,
                        path=None, copy=True, **kwargs):
        '''
        Return this node's part of `data`.

        If h5py is available and `alias` is given, `data` is replaced by the
        result of ``load_data``. Then:

        - if no node holds data, an uninitialised array of ``local_shape``
          is returned;
        - if every node holds data, the rows ``local_start:local_end`` are
          returned, cast to ``self.dtype``;
        - otherwise a ValueError is raised.
        '''
        comm = self.comm

        if 'h5py' in gdi and alias is not None:
            data = self.load_data(alias=alias, path=path)

        have_data = np.array(comm.allgather(data is not None))

        # no node got data: hand out an empty container
        if not have_data.any():
            return np.empty(self.local_shape, dtype=self.dtype, order='C')

        # only some nodes got data: refuse
        if not have_data.all():
            raise ValueError(
                "ERROR: distribute_data must get data on all nodes!")

        # all nodes got data: it is assumed to be the right data and each
        # node stores its own part.
        if isinstance(data, distributed_data_object):
            local_window = (slice(self.local_start, self.local_end),)
            extracted = data.get_data(local_window, local_keys=True)
            return extracted.get_local_data().astype(self.dtype, copy=copy)

        local_rows = data[self.local_start:self.local_end]
        return local_rows.astype(self.dtype, copy=copy)
555
556

    def _disperse_data_primitive(self, data, to_key, data_update, from_key,
                                 copy, to_found, to_found_boolean, from_found,
                                 from_found_boolean, **kwargs):
        """Dispatch one disperse operation according to the kind of to_key.

        `to_found`, `to_found_boolean` and `from_found` are the
        classifications of the keys as produced by `_infer_key_type`.
        A scalar `data_update` makes `from_key` irrelevant. Returns
        whatever the selected disperse method returns; nothing is returned
        if `to_found` matches none of the handled kinds.
        """
        if np.isscalar(data_update):
            from_key = None

        def presliced_update():
            # Used by every branch that cannot forward from_key: the update
            # is indexed with from_key up front instead.
            if from_key is None:
                return data_update
            about.infos.cprint(
                "INFO: Advanced injection is not available for this " +
                "combination of to_key and from_key.")
            return data_update[from_key]

        # Case 1: to_key is a slice-tuple. Hence, the basic indexing/slicing
        # machinery will be used
        if to_found == 'slicetuple':
            if from_found == 'slicetuple':
                return self.disperse_data_to_slices(data=data,
                                                    to_slices=to_key,
                                                    data_update=data_update,
                                                    from_slices=from_key,
                                                    copy=copy,
                                                    **kwargs)
            return self.disperse_data_to_slices(
                                        data=data,
                                        to_slices=to_key,
                                        data_update=presliced_update(),
                                        copy=copy,
                                        **kwargs)

        # Case 2: key is an array
        if to_found in ('ndarray', 'd2o'):
            # Case 2.1: The array is boolean.
            if to_found_boolean:
                return self.disperse_data_to_bool(
                                          data=data,
                                          to_boolean_key=to_key,
                                          data_update=presliced_update(),
                                          copy=copy,
                                          **kwargs)
            # Case 2.2: The array is not boolean. Only 1-dimensional
            # advanced slicing is supported.
            if len(to_key.shape) != 1:
                raise ValueError(about._errors.cstring(
                    "WARNING: Only one-dimensional advanced indexing " +
                    "is supported"))
            # Make a recursive call in order to trigger the 'list'-section
            return self.disperse_data(data=data, to_key=[to_key],
                                      data_update=data_update,
                                      from_key=from_key, copy=copy,
                                      **kwargs)

        # Case 3 : to_key is a list. This list is interpreted as
        # one-dimensional advanced indexing list.
        if to_found == 'indexinglist':
            return self.disperse_data_to_list(data=data,
                                              to_list_key=to_key,
                                              data_update=presliced_update(),
                                              copy=copy,
                                              **kwargs)
633
634

    def disperse_data_to_list(self, data, to_list_key, data_update,
                              copy=True, **kwargs):
        """Inject `data_update` into `data` at an advanced-indexing list key.

        Parameters
        ----------
        data : array-like
            The local data that gets updated.
        to_list_key : list
            One-dimensional advanced indexing list. An empty list addresses
            nothing, hence `data` is handed back untouched.
        data_update : object
            The new values; passed on unchanged to the shared list/bool
            helper.
        copy : bool
            Passed on unchanged to the shared list/bool helper.
        **kwargs
            Passed on unchanged to the shared list/bool helper.

        Returns
        -------
        `data` itself for an empty key, otherwise the return value of
        `_disperse_data_to_list_and_bool_helper`.
        """
        # An empty indexing list selects no element at all -> nothing to do.
        if to_list_key == []:
            return data

        # NOTE(review): presumably this maps the global indices of the key
        # onto indices valid for the local data -- confirm in
        # `_advanced_index_decycler`.
        decycled_key = self._advanced_index_decycler(to_list_key)

        scatter = self._disperse_data_to_list_and_bool_helper
        return scatter(data=data,
                       local_to_key=decycled_key,
                       data_update=data_update,
                       copy=copy,
                       **kwargs)
647
648

    def disperse_data_to_bool(self, data, to_boolean_key, data_update,
                              copy=True, **kwargs):
        """Inject `data_update` into `data` at the positions of a boolean key.

        Parameters
        ----------
        data : array-like
            The local data that gets updated.
        to_boolean_key : array-like
            Boolean key; only its local part (as returned by
            `extract_local_data`) is applied to `data`.
        data_update : object
            The new values; passed on unchanged to the shared list/bool
            helper.
        copy : bool
            Passed on unchanged to the shared list/bool helper.
        **kwargs
            Passed on unchanged to the shared list/bool helper.

        Returns
        -------
        The return value of `_disperse_data_to_list_and_bool_helper`.
        """
        # Reduce the key to the piece that corresponds to the local data.
        local_mask = self.extract_local_data(to_boolean_key)

        scatter = self._disperse_data_to_list_and_bool_helper
        return scatter(data=data,
                       local_to_key=local_mask,
                       data_update=data_update,
                       copy=copy,
                       **kwargs)
659

660
    def _disperse_data_to_list_and_bool_helper(self, data, local_to_key,
                                               data_update, copy, **kwargs):
        """Shared back end of `disperse_data_to_list` / `disperse_data_to_bool`.

        Every node counts how many entries of its `data` are addressed by
        `local_to_key`. The counts are all-gathered, and their exclusive
        cumulative sum gives each node the offset of its own portion inside
        `data_update`. That portion is then written into
        `data[local_to_key]`.

        Parameters
        ----------
        data : array-like
            Local data; modified in place.
        local_to_key : object
            Key that is valid for indexing the local `data`.
        data_update : distributed_data_object, scalar or array-like
            A scalar is assigned as is. Otherwise the node's portion is cut
            out of it and cast to `self.dtype`.
        copy : bool
            Handed to `np.array` when `data_update` is a plain array-like.
        **kwargs
            Accepted but not used here.

        Returns
        -------
        The updated `data`.
        """
        communicator = self.comm
        my_rank = communicator.rank

        # Number of entries the key addresses on this node.
        local_count = np.shape(data[local_to_key])[0]

        # Exclusive cumulative sum of all per-node counts: entry i is the
        # position inside data_update where node i's portion begins.
        counts_per_node = communicator.allgather(local_count)
        offsets = np.append([0], np.cumsum(counts_per_node)[:-1])

        if isinstance(data_update, distributed_data_object):
            begin = offsets[my_rank]
            portion = data_update.get_data(slice(begin, begin + local_count),
                                           local_keys=True)
            data[local_to_key] = portion.get_local_data().astype(self.dtype)
        elif np.isscalar(data_update):
            # A scalar needs no partitioning.
            data[local_to_key] = data_update
        else:
            begin = offsets[my_rank]
            portion = np.array(data_update[begin:begin + local_count],
                               copy=copy)
            data[local_to_key] = portion.astype(self.dtype)

        return data

    def disperse_data_to_slices(self, data, to_slices,
                                data_update, from_slices=None, copy=True):
        """Inject `data_update` into the part of `data` addressed by slices.

        The first entry of `to_slices` is localized with respect to
        `self.local_start` / `self.local_end`. The number of locally affected
        rows is exchanged among all nodes so that every node knows which
        stretch of `data_update` belongs to it.

        Parameters
        ----------
        data : array-like
            Local data; modified in place.
        to_slices : object
            Slicing key in global coordinates; normalized by `_sliceify`.
        data_update : distributed_data_object, scalar or array-like
            The new values. A scalar is assigned directly to the local slice.
        from_slices : tuple of slices, optional
            Selects the part of `data_update` that is injected. Defaults to
            the complete first axis.
        copy : bool
            Passed to `get_local_data` (d2o update) or `np.array`
            (array-like update).

        Returns
        -------
        The updated `data`, as the list and boolean disperse routines do.

        Raises
        ------
        ValueError
            If the step of the first `to_slices` or `from_slices` entry is
            zero, or if `_backshift_and_decycle` yields no start for the
            from-slice.
        """
        comm = self.comm
        (to_slices, sliceified) = self._sliceify(to_slices)

        # Check the step size before the slice is used for anything: indexing
        # with a zero step would fail with a less specific error otherwise.
        to_step = to_slices[0].step
        if to_step is None:
            to_step = 1
        elif to_step == 0:
            raise ValueError(about._errors.cstring(
                "ERROR: to_step size == 0!"))

        # parse the to_slices object
        localized_to_start, localized_to_stop = self._backshift_and_decycle(
            to_slices[0], self.local_start, self.local_end,
            self.global_shape[0])
        local_to_slice = (slice(localized_to_start, localized_to_stop,
                                to_slices[0].step),) + to_slices[1:]
        local_to_slice_shape = data[local_to_slice].shape

        # Compute the offset of the data the individual node will take.
        # The offset is free of stepsizes. It is the offset in terms of
        # the purely transported data. If to_step < 0, the offset will
        # be calculated in reverse order
        order = np.sign(to_step)

        local_affected_data_length = local_to_slice_shape[0]
        # The buffers are declared as MPI.INT (C int), hence they are built
        # with the matching numpy type np.intc.
        local_affected_data_length_list = np.empty(comm.size, dtype=np.intc)
        comm.Allgather(
            [np.array(local_affected_data_length, dtype=np.intc), MPI.INT],
            [local_affected_data_length_list, MPI.INT])
        local_affected_data_length_offset_list = np.append(
            [0],
            np.cumsum(local_affected_data_length_list[::order])[:-1])[::order]

        if np.isscalar(data_update):
            # A scalar needs no partitioning.
            data[local_to_slice] = data_update
            return data

        # construct the locally adapted from_slice object
        rank = comm.rank
        offsets = local_affected_data_length_offset_list

        data_update = self._enfold(data_update, sliceified)

        # parse the from_slices object
        if from_slices is None:
            from_slices = (slice(None, None, None),)

        # parse the step sizes
        from_step = from_slices[0].step
        if from_step is None:
            from_step = 1
        elif from_step == 0:
            raise ValueError(about._errors.cstring(
                "ERROR: from_step size == 0!"))

        (from_slices_start, _) = self._backshift_and_decycle(
            slice_object=from_slices[0],
            shifted_start=0,
            shifted_stop=data_update.shape[0],
            global_length=data_update.shape[0])
        if from_slices_start is None:
            raise ValueError(about._errors.cstring(
                "ERROR: _backshift_and_decycle should never return " +
                "None for local_start!"))

        localized_from_start = from_slices_start + from_step * offsets[rank]
        localized_from_stop = (localized_from_start +
                               from_step * local_affected_data_length)
        # A negative stop would be read as "counted from the end"; None lets
        # the slice run up to the array boundary instead.
        if localized_from_stop < 0:
            localized_from_stop = None

        update_slice = (slice(localized_from_start,
                              localized_from_stop,
                              from_step),) + from_slices[1:]

        if isinstance(data_update, distributed_data_object):
            local_data_update = data_update.get_data(
                key=update_slice,
                local_keys=True
            ).get_local_data(copy=copy).astype(self.dtype)
            # Nodes without affected data must not assign anything.
            if np.prod(np.shape(local_data_update)) != 0:
                data[local_to_slice] = local_data_update
        else:
            local_data_update = np.array(data_update)[update_slice]
            if np.prod(np.shape(local_data_update)) != 0:
                data[local_to_slice] = np.array(
                    local_data_update,
                    copy=copy).astype(self.dtype)

        return data
780

781
    def collect_data(self, data, key, local_keys=False, **kwargs):
        """Collect the entries of `data` that are addressed by `key`.

        Three kinds of keys are supported:
          1. a slicing/index tuple,
          2. a boolean array of the same shape as self,
          3. a list of shape (n,), where 0 < n < len(self.shape). Its
             entries must be scalar/list/tuple/ndarray; if they are not
             scalar they must all have the same length. This is numpy
             advanced indexing in one dimension only.

        Parameters
        ----------
        data : array-like
            The local data to collect from.
        key : object
            The key; with `local_keys` set, every node passes its own key.
        local_keys : bool
            If exactly `False`, `key` is evaluated directly. Otherwise the
            keys of all nodes are gathered and evaluated one after another,
            and each node keeps the result belonging to its own key.
        **kwargs
            Passed on to `_collect_data_primitive`.

        Returns
        -------
        The result of `_collect_data_primitive` if `local_keys` is `False`,
        otherwise a 'freeform' distributed_data_object built from the
        node-individual results.
        """
        # Check which case we got:
        key_kind, key_is_boolean = _infer_key_type(key)
        comm = self.comm

        if local_keys is False:
            return self._collect_data_primitive(data, key, key_kind,
                                                key_is_boolean, **kwargs)

        # All nodes must have supplied the same kind of key.
        kinds_of_all_nodes = comm.allgather(key_kind)
        assert(all(kind == kinds_of_all_nodes[0]
                   for kind in kinds_of_all_nodes))
        boolean_flags_of_all_nodes = comm.allgather(key_is_boolean)
        assert(all(flag == boolean_flags_of_all_nodes[0]
                   for flag in boolean_flags_of_all_nodes))

        # Gather the node-individual keys into one global list of keys.
        if key_kind == 'd2o':
            # distributed_data_objects are not gathered themselves. Their
            # indices are exchanged and looked up in the librarian.
            librarian_indices = comm.allgather(key.index)
            keys_of_all_nodes = [d2o_librarian[index]
                                 for index in librarian_indices]
        else:
            keys_of_all_nodes = comm.allgather(key)

        for target_rank, node_key in enumerate(keys_of_all_nodes):
            # build the locally fed d2o
            partial_d2o = self._collect_data_primitive(data, node_key,
                                                       key_kind,
                                                       key_is_boolean,
                                                       **kwargs)
            # collect the data stored in the d2o to the individual target
            # rank
            gathered = partial_d2o.get_full_data(target_rank=target_rank)
            if comm.rank == target_rank:
                individual_data = gathered

        return distributed_data_object(local_data=individual_data,
                                       distribution_strategy='freeform',
                                       comm=self.comm)
832

833
    def _collect_data_primitive(self, data, key, found, found_boolean,
834
835
836
837
                                **kwargs):

        # Case 1: key is a slice-tuple. Hence, the basic indexing/slicing
        # machinery will be used
838
        if found == 'slicetuple':
839
840
            return self.collect_data_from_slices(data=data,
                                                 slice_objects=key,
841
842
                                                 **kwargs)
        # Case 2: key is an array
Ultima's avatar
Ultima committed
843
        elif (found == 'ndarray' or found == 'd2o'):
844
            # Case 2.1: The array is boolean.
845
846
847
            if found_boolean:
                return self.collect_data_from_bool(data=data,
                                                   boolean_key=key,
Ultima's avatar
Ultima committed
848
                                                   **kwargs)
849
850
            # Case 2.2: The array is not boolean. Only 1-dimensional
            # advanced slicing is supported.
Ultima's avatar
Ultima committed
851
852
853
854
855
            else:
                if len(key.shape) != 1:
                    raise ValueError(about._errors.cstring(
                        "WARNING: Only one-dimensional advanced indexing " +
                        "is supported"))
856
                # Make a recursive call in order to trigger the 'list'-section
857
                return self.collect_data(data=data, key=[key], **kwargs)
858
859
860

        # Case 3 : key is a list. This list is interpreted as one-dimensional
        # advanced indexing list.
861
        elif found == 'indexinglist':
862
863
864
            return self.collect_data_from_list(data=data,
                                               list_key=key,
                                               **kwargs)