nifty_mpi_data.py 78.6 KB
Newer Older
ultimanet's avatar
ultimanet committed
1
# -*- coding: utf-8 -*-
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
## NIFTY (Numerical Information Field Theory) has been developed at the
## Max-Planck-Institute for Astrophysics.
##
## Copyright (C) 2015 Max-Planck-Society
##
## Author: Theo Steininger
## Project homepage: <http://www.mpa-garching.mpg.de/ift/nifty/>
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
## See the GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <http://www.gnu.org/licenses/>.


ultimanet's avatar
ultimanet committed
24
25


26

27
##initialize the 'found-packages'-dictionary 
28
found = {}
ultimanet's avatar
ultimanet committed
29
import numpy as np
Ultimanet's avatar
Ultimanet committed
30
from nifty_about import about
ultimanet's avatar
ultimanet committed
31
32

try:
33
    from mpi4py import MPI
34
    found['MPI'] = True
ultimanet's avatar
ultimanet committed
35
except(ImportError): 
36
    import mpi_dummy as MPI
37
    found['MPI'] = False
ultimanet's avatar
ultimanet committed
38
39
40
41
42
43
44
45

try:
    import pyfftw
    found['pyfftw'] = True
except(ImportError):       
    found['pyfftw'] = False

try:
46
    import h5py
ultimanet's avatar
ultimanet committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
    found['h5py'] = True
    found['h5py_parallel'] = h5py.get_config().mpi
except(ImportError):
    found['h5py'] = False
    found['h5py_parallel'] = False


class distributed_data_object(object):
    """

        NIFTY class for distributed data

        Parameters
        ----------
        global_data : {tuple, list, numpy.ndarray} *at least 1-dimensional*
            Initial data which will be cast to a numpy.ndarray and then
            stored according to the distribution strategy. If global_data is
            neither None nor a scalar, its shape overrides global_shape.
        global_shape : tuple of ints, *optional*
            If no global_data is supplied, global_shape can be used to
            initialize an empty distributed_data_object
        dtype : type, *optional*
            If an explicit dtype is supplied, the given global_data will be
            cast to it.
        distribution_strategy : {'fftw' (default), 'not'}, *optional*
            Specifies the way, how global_data will be distributed to the 
            individual nodes. 
            'fftw' follows the distribution strategy of pyfftw.
            'not' does not distribute the data at all. 
            

        Attributes
        ----------
        data : numpy.ndarray
            The numpy.ndarray in which the individual node's data is stored.
        dtype : type
            Data type of the data object.
        distribution_strategy : string
            Name of the used distribution_strategy
        distributor : distributor
            The distributor object which takes care of all distribution and 
            consolidation of the data. 
        shape : tuple of int
            The global shape of the data
            
        Raises
        ------
        TypeError : 
            If the supplied distribution strategy is not known. 
        
    """
98
99
100
    def __init__(self, global_data=None, global_shape=None, dtype=None,
                 distribution_strategy='fftw', hermitian=False,
                 alias=None, path=None, comm=MPI.COMM_WORLD,
                 copy=True, *args, **kwargs):
        """
            Creates a distributed_data_object.

            Global shape and dtype are taken, in order of precedence, from an
            hdf5 dataset (if `alias` is given and h5py is available), from
            `global_data`, or from the explicit `global_shape`/`dtype`.

            Parameters
            ----------
            global_data : scalar, tuple, list, numpy.ndarray, *optional*
                Initial data. A scalar fills the whole array with its value.
            global_shape : tuple of ints, *optional*
                Only used if global_data is None or a scalar.
            dtype : type, *optional*
                Explicit dtype; overrides the dtype of global_data.
            distribution_strategy : str, *optional*
                Forwarded to distributor_factory.get_distributor.
            hermitian : bool, *optional*
                Forwarded to set_full_data.
            alias : str, *optional*
                Name of an hdf5 dataset the data is loaded from.
            path : str, *optional*
                Path of the hdf5 file; defaults to `alias`.
            comm : MPI communicator, *optional*
                Used when opening the hdf5 file with the 'mpio' driver.
            copy : bool, *optional*
                Forwarded to set_full_data.

            Raises
            ------
            ValueError
                If neither dtype nor global_data is supplied, or neither a
                non-scalar global_data nor global_shape is supplied.
        """
        load_from_hdf5 = bool(found['h5py']) and alias is not None

        if load_from_hdf5:
            ## a given hdf5 file overwrites the other parameters
            file_path = path if (path is not None) else alias
            if found['h5py_parallel'] and found['MPI']:
                f = h5py.File(file_path, 'r', driver='mpio', comm=comm)
            else:
                f = h5py.File(file_path, 'r')
            ## Only shape and dtype are read here; the data itself is read by
            ## self.load() below (which receives alias/path, not this handle).
            ## Close the handle even if the dataset lookup fails.
            try:
                dset = f[alias]
                global_shape = dset.shape
                dtype = dset.dtype.type
            finally:
                f.close()

        ## if no hdf5 path was given, extract global_shape and dtype from
        ## the remaining arguments
        else:
            ## an explicitly given dtype overwrites the one from global_data
            if dtype is None:
                if global_data is None:
                    raise ValueError(about._errors.cstring(
                        "ERROR: Neither global_data nor dtype supplied!"))
                try:
                    dtype = global_data.dtype.type
                except AttributeError:
                    try:
                        dtype = global_data.dtype
                    except AttributeError:
                        dtype = np.array(global_data).dtype.type
            else:
                dtype = np.dtype(dtype).type

            ## an explicitly given global_shape is only used if global_data
            ## is None or a scalar. np.shape (instead of .shape) also works
            ## for list/tuple input, which the docstring allows.
            if global_data is None or np.isscalar(global_data):
                if global_shape is None:
                    raise ValueError(about._errors.cstring(
    "ERROR: Neither non-0-dimensional global_data nor global_shape supplied!"))
                global_shape = tuple(global_shape)
            else:
                global_shape = np.shape(global_data)

        self.distributor = distributor_factory.get_distributor(
                                distribution_strategy=distribution_strategy,
                                global_shape=global_shape,
                                dtype=dtype,
                                **kwargs)

        self.distribution_strategy = distribution_strategy
        self.dtype = self.distributor.dtype
        self.shape = self.distributor.global_shape

        self.init_args = args
        self.init_kwargs = kwargs

        ## If a hdf5 path was given, load the data
        if load_from_hdf5:
            self.load(alias=alias, path=path)
        ## If the input data was a scalar, set the whole array to this value.
        ## `is not None` is required: `!= None` compares element-wise for
        ## numpy arrays and cannot be used in a boolean context.
        elif global_data is not None and np.isscalar(global_data):
            temp = np.empty(self.distributor.local_shape, dtype=self.dtype)
            temp.fill(global_data)
            self.set_local_data(temp)
            self.hermitian = True
        else:
            self.set_full_data(data=global_data, hermitian=hermitian,
                               copy=copy, **kwargs)
181
            
Ultimanet's avatar
Ultimanet committed
182
183
184
185
186
187
188
189
    def copy(self, dtype=None, distribution_strategy=None, **kwargs):
        """
            Returns a copy of self, optionally with a different dtype and/or
            distribution strategy.

            With an unchanged strategy the node-local data is copied
            directly; otherwise the data is redistributed via inject.
        """
        duplicate = self.copy_empty(dtype=dtype,
                                    distribution_strategy=distribution_strategy,
                                    **kwargs)
        keeps_strategy = distribution_strategy in (None,
                                                   self.distribution_strategy)
        if not keeps_strategy:
            duplicate.inject([slice(None),], self, [slice(None),])
        else:
            duplicate.set_local_data(self.get_local_data(), copy=True)
        duplicate.hermitian = self.hermitian
        return duplicate
    
195
196
197
198
199
200
201
202
203
204
205
206
207
208
    def copy_empty(self, global_shape=None, dtype=None,
                   distribution_strategy=None, **kwargs):
        """
            Returns a new distributed_data_object without initial data.

            Parameters
            ----------
            global_shape, dtype, distribution_strategy : *optional*
                Default to the corresponding attributes of self.
            **kwargs
                Extra keyword arguments for the constructor. They take
                precedence over the keyword arguments self was created with,
                which are used as defaults.

            Returns
            -------
            distributed_data_object
        """
        ## `is None` instead of `== None`: an ndarray global_shape would
        ## otherwise be compared element-wise
        if global_shape is None:
            global_shape = self.shape
        if dtype is None:
            dtype = self.dtype
        if distribution_strategy is None:
            distribution_strategy = self.distribution_strategy

        ## Start from the stored init kwargs and let explicitly passed kwargs
        ## win. Previously the stored ones silently overrode the explicit ones.
        new_kwargs = dict(self.init_kwargs)
        new_kwargs.update(kwargs)

        temp_d2o = distributed_data_object(global_shape=global_shape,
                                           dtype=dtype,
                                           distribution_strategy=distribution_strategy,
                                           *self.init_args,
                                           **new_kwargs)
        return temp_d2o
    
213
    def apply_scalar_function(self, function, inplace=False, dtype=None):
        """
            Applies `function` element-wise to the data.

            Parameters
            ----------
            function : callable
                Tried on the whole local array first; if that raises, it is
                applied element-wise via np.vectorize.
            inplace : bool
                If True, the result is written into self (dtype unchanged).
            dtype : type, *optional*
                dtype of the result if not inplace.

            Returns
            -------
            distributed_data_object
                self if inplace, else a new object.
        """
        remember_hermitianQ = self.hermitian

        if inplace:
            temp = self
            if dtype is not None and self.dtype != dtype:
                about.warnings.cprint(
                    "WARNING: Inplace dtype conversion is not possible!")
        else:
            temp = self.copy_empty(dtype=dtype)

        ## Fast path: array-wide call. Fall back to element-wise evaluation
        ## for functions that cannot process arrays. `Exception` (not a bare
        ## except) so that KeyboardInterrupt/SystemExit are not swallowed.
        try:
            temp.data[:] = function(self.data)
        except Exception:
            temp.data[:] = np.vectorize(function)(self.data)

        ## only np.exp and np.log keep the previous hermitian flag; any other
        ## function clears it
        if function in (np.exp, np.log):
            temp.hermitian = remember_hermitianQ
        else:
            temp.hermitian = False
        return temp
    
    def apply_generator(self, generator):
        """
            Replaces the local data by generator(local_shape) and clears the
            hermitian flag.
        """
        local_shape = self.distributor.local_shape
        generated = generator(local_shape)
        self.set_local_data(generated)
        self.hermitian = False
            
ultimanet's avatar
ultimanet committed
240
241
242
243
244
245
    def __str__(self):
        """String form of the node-local data array."""
        return str(self.data)
    
    def __repr__(self):
        """Tagged repr of the node-local data array."""
        header = '<distributed_data_object>\n'
        return header + repr(self.data)
    
246
247
    
    def _compare_helper(self, other, op):
        """
            Element-wise comparison of self with `other`.

            Parameters
            ----------
            other : scalar, numpy.ndarray, distributed_data_object, None or
                anything numpy can cast to an array
            op : str
                Name of the ndarray comparison method, e.g. '__eq__'.

            Returns
            -------
            distributed_data_object of dtype np.bool_, or a plain bool if
            other is None: True for '__ne__', False otherwise.
        """
        ## Case 1: 'other' is None -> nothing is equal to None. Previously
        ## False was returned for every op, including '__ne__'.
        if other is None:
            return op == '__ne__'

        if np.isscalar(other):
            ## Case 2: scalar -> point-wise comparison
            temp_data = other
        elif isinstance(other, (np.ndarray, distributed_data_object)):
            ## Case 3: extract the node-local part of other
            temp_data = self.distributor.extract_local_data(other)
        else:
            ## Case 4: something different -> cast to ndarray and recurse
            return getattr(self, op)(np.array(other))

        result = self.copy_empty(dtype=np.bool_)
        result.set_local_data(
                getattr(self.get_local_data(copy=False), op)(temp_data))
        return result
Ultimanet's avatar
Ultimanet committed
274
        
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293

    def __ne__(self, other):
        """Element-wise self != other."""
        return self._compare_helper(other=other, op='__ne__')
        
    def __lt__(self, other):
        """Element-wise self < other."""
        return self._compare_helper(other=other, op='__lt__')
            
    def __le__(self, other):
        """Element-wise self <= other."""
        return self._compare_helper(other=other, op='__le__')

    def __eq__(self, other):
        """Element-wise self == other."""
        return self._compare_helper(other=other, op='__eq__')
    def __ge__(self, other):
        """Element-wise self >= other."""
        return self._compare_helper(other=other, op='__ge__')

    def __gt__(self, other):
        """Element-wise self > other."""
        return self._compare_helper(other=other, op='__gt__')

Ultimanet's avatar
Ultimanet committed
294
    def equal(self, other):
        """
            Returns True if other is a distributed_data_object with the same
            dtype, shape, init arguments, distribution strategy and equal
            local data; False otherwise (also for None or for objects that
            lack one of these attributes).

            The checks are explicit comparisons rather than `assert`
            statements, which are stripped under `python -O` and would make
            this method always return True.
        """
        if other is None:
            return False
        try:
            return bool(
                self.dtype == other.dtype and
                self.shape == other.shape and
                self.init_args == other.init_args and
                self.init_kwargs == other.init_kwargs and
                self.distribution_strategy == other.distribution_strategy and
                np.all(self.data == other.data))
        except AttributeError:
            return False
        

            
    
312
    def __pos__(self):
        """Unary plus: returns a copy of self."""
        duplicate = self.copy_empty()
        local_values = self.get_local_data()
        duplicate.set_local_data(data=local_values, copy=True)
        return duplicate
        
ultimanet's avatar
ultimanet committed
317
    def __neg__(self):
        """Unary minus: returns a new object holding -self."""
        negated = self.copy_empty()
        local_values = self.get_local_data().__neg__()
        negated.set_local_data(data=local_values, copy=True)
        return negated
    
323
    def __abs__(self):
        """
            Returns the element-wise absolute value as a new
            distributed_data_object.

            Complex data yields the real floating-point type of the same
            precision (e.g. complex64 -> float32, complex128 -> float64);
            all other dtypes are kept.
        """
        ## np.finfo of a complex type describes its real components, so its
        ## dtype is the matching float type. This avoids the numpy aliases
        ## np.complex/np.float (removed in numpy 1.24) and no longer
        ## truncates extended-precision complex data to float64.
        if np.issubdtype(self.dtype, np.complexfloating):
            new_dtype = np.finfo(self.dtype).dtype.type
        else:
            new_dtype = self.dtype
        temp_d2o = self.copy_empty(dtype=new_dtype)
        temp_d2o.set_local_data(data=self.get_local_data().__abs__(),
                                copy=True)
        return temp_d2o
ultimanet's avatar
ultimanet committed
339
            
340
    def __builtin_helper__(self, operator, other, inplace=False):
        """
            Common implementation of the arithmetic operators.

            Parameters
            ----------
            operator : callable
                A bound method of the node-local data array (e.g. its
                __add__). It is applied to the node-local part of `other`.
            other : scalar, numpy.ndarray, distributed_data_object, ...
            inplace : bool
                If True, the result is stored in self. Otherwise it goes into
                a new object whose dtype is np.find_common_type of self.dtype
                and the result's dtype.

            Returns
            -------
            distributed_data_object
        """
        ## Complex dtypes for which a complex scalar clears the hermitian
        ## flag. `complex` is what the removed alias np.complex referred to;
        ## np.complex256 does not exist on every platform.
        complex_types = (complex, np.complex128)
        if hasattr(np, 'complex256'):
            complex_types += (np.complex256,)

        ## Case 1: other is not a scalar
        if not (np.isscalar(other) or np.shape(other) == (1,)):
            try:
                hermitian_Q = (other.hermitian and self.hermitian)
            except AttributeError:
                hermitian_Q = False
            ## extract the local data from the 'other' object
            temp_data = self.distributor.extract_local_data(other)
            temp_data = operator(temp_data)

        ## Case 2: other is a real scalar (or self is not complex-typed)
        ## -> preserve hermitianity
        elif np.isreal(other) or (self.dtype not in complex_types):
            hermitian_Q = self.hermitian
            temp_data = operator(other)

        ## Case 3: other is complex
        else:
            hermitian_Q = False
            temp_data = operator(other)

        ## write the new data into self or into a new object
        if inplace:
            temp_d2o = self
        else:
            ## use common datatype for self and other
            new_dtype = np.dtype(np.find_common_type((self.dtype,),
                                                     (temp_data.dtype,))).type
            temp_d2o = self.copy_empty(dtype=new_dtype)
        temp_d2o.set_local_data(data=temp_data)
        ## set after set_local_data, which resets the flag
        temp_d2o.hermitian = hermitian_Q
        return temp_d2o
375
    """
Ultimanet's avatar
Ultimanet committed
376
    def __inplace_builtin_helper__(self, operator, other):
377
        ## Case 1: other is not a scalar
Ultimanet's avatar
Ultimanet committed
378
379
380
        if not (np.isscalar(other) or np.shape(other) == (1,)):        
            temp_data = self.distributor.extract_local_data(other)
            temp_data = operator(temp_data)
381
382
383
        ## Case 2: other is a real scalar -> preserve hermitianity
        elif np.isreal(other):
            hermitian_Q = self.hermitian
Ultimanet's avatar
Ultimanet committed
384
            temp_data = operator(other)
385
386
387
        ## Case 3: other is complex
        else:
            temp_data = operator(other)        
Ultimanet's avatar
Ultimanet committed
388
        self.set_local_data(data=temp_data)
389
        self.hermitian = hermitian_Q
Ultimanet's avatar
Ultimanet committed
390
        return self
391
    """ 
Ultimanet's avatar
Ultimanet committed
392
    
ultimanet's avatar
ultimanet committed
393
394
395
396
397
    def __add__(self, other):
        """Element-wise self + other."""
        local_operation = self.get_local_data().__add__
        return self.__builtin_helper__(local_operation, other)

    def __radd__(self, other):
        """Element-wise other + self."""
        local_operation = self.get_local_data().__radd__
        return self.__builtin_helper__(local_operation, other)
Ultimanet's avatar
Ultimanet committed
398
399

    def __iadd__(self, other):
        """In-place element-wise self += other."""
        local_operation = self.get_local_data().__iadd__
        return self.__builtin_helper__(local_operation, other, inplace=True)
Ultimanet's avatar
Ultimanet committed
403

ultimanet's avatar
ultimanet committed
404
405
406
407
408
409
410
    def __sub__(self, other):
        """Element-wise self - other."""
        local_operation = self.get_local_data().__sub__
        return self.__builtin_helper__(local_operation, other)
    
    def __rsub__(self, other):
        """Element-wise other - self."""
        local_operation = self.get_local_data().__rsub__
        return self.__builtin_helper__(local_operation, other)
    
    def __isub__(self, other):
        """In-place element-wise self -= other."""
        local_operation = self.get_local_data().__isub__
        return self.__builtin_helper__(local_operation, other, inplace=True)
ultimanet's avatar
ultimanet committed
414
415
416
417
        
    def __div__(self, other):
        """Element-wise classic (Python 2) division self / other."""
        local_operation = self.get_local_data().__div__
        return self.__builtin_helper__(local_operation, other)
    
418
419
420
    def __truediv__(self, other):
        """
            Element-wise true division self / other.

            Uses the true-division method of the local data. Delegating to
            __div__ floor-divided integer data under Python 2 with
            `from __future__ import division`, and fails on Python 3, where
            ndarray has no __div__.
        """
        return self.__builtin_helper__(self.get_local_data().__truediv__,
                                       other)
        
ultimanet's avatar
ultimanet committed
421
422
    def __rdiv__(self, other):
        """Element-wise classic (Python 2) division other / self."""
        local_operation = self.get_local_data().__rdiv__
        return self.__builtin_helper__(local_operation, other)
423
424
425
    
    def __rtruediv__(self, other):
        """
            Element-wise true division other / self.

            Uses the reflected true-division method of the local data instead
            of delegating to __rdiv__, which is classic division on Python 2
            and does not exist on Python 3 ndarrays.
        """
        return self.__builtin_helper__(self.get_local_data().__rtruediv__,
                                       other)
ultimanet's avatar
ultimanet committed
426

Ultimanet's avatar
Ultimanet committed
427
    def __idiv__(self, other):
        """In-place element-wise classic (Python 2) division self /= other."""
        local_operation = self.get_local_data().__idiv__
        return self.__builtin_helper__(local_operation, other, inplace=True)
431
    def __itruediv__(self, other):
        """
            In-place element-wise true division self /= other.

            Uses the in-place true-division method of the local data instead
            of delegating to __idiv__, which is classic division on Python 2
            and does not exist on Python 3 ndarrays. As with numpy arrays,
            integer data cannot hold the float result and numpy raises.
        """
        return self.__builtin_helper__(self.get_local_data().__itruediv__,
                                       other,
                                       inplace=True)
                                               
ultimanet's avatar
ultimanet committed
434
    def __floordiv__(self, other):
        """Element-wise floor division self // other."""
        local_operation = self.get_local_data().__floordiv__
        return self.__builtin_helper__(local_operation, other)
ultimanet's avatar
ultimanet committed
437
    def __rfloordiv__(self, other):
        """Element-wise floor division other // self."""
        local_operation = self.get_local_data().__rfloordiv__
        return self.__builtin_helper__(local_operation, other)
    def __ifloordiv__(self, other):
        """In-place element-wise floor division self //= other."""
        local_operation = self.get_local_data().__ifloordiv__
        return self.__builtin_helper__(local_operation, other, inplace=True)
ultimanet's avatar
ultimanet committed
444
445
446
447
448
449
450
451
    
    def __mul__(self, other):
        """Element-wise self * other."""
        local_operation = self.get_local_data().__mul__
        return self.__builtin_helper__(local_operation, other)
    
    def __rmul__(self, other):
        """Element-wise other * self."""
        local_operation = self.get_local_data().__rmul__
        return self.__builtin_helper__(local_operation, other)

    def __imul__(self, other):
        """In-place element-wise self *= other."""
        local_operation = self.get_local_data().__imul__
        return self.__builtin_helper__(local_operation, other, inplace=True)
Ultimanet's avatar
Ultimanet committed
455

ultimanet's avatar
ultimanet committed
456
457
458
459
460
461
462
    def __pow__(self, other):
        """Element-wise self ** other."""
        local_operation = self.get_local_data().__pow__
        return self.__builtin_helper__(local_operation, other)
 
    def __rpow__(self, other):
        """Element-wise other ** self."""
        local_operation = self.get_local_data().__rpow__
        return self.__builtin_helper__(local_operation, other)

    def __ipow__(self, other):
        """In-place element-wise self **= other."""
        local_operation = self.get_local_data().__ipow__
        return self.__builtin_helper__(local_operation, other, inplace=True)
Ultimanet's avatar
Ultimanet committed
466
   
467
468
    def __len__(self):
        """Length of the first axis of the global shape."""
        leading_axis = 0
        return self.shape[leading_axis]
469
    
470
    def get_dim(self):
        """Total number of elements, i.e. the product of the global shape."""
        global_shape = self.shape
        return np.prod(global_shape)
        
473
    def vdot(self, other):
        """
            Global np.vdot of self and other: every node computes np.vdot on
            its local parts, and the partial results of all nodes are summed.
        """
        local_other = self.distributor.extract_local_data(other)
        partial = np.vdot(self.get_local_data(), local_other)
        return np.sum(self.distributor._allgather(partial))
            
Ultimanet's avatar
Ultimanet committed
480

481
    
ultimanet's avatar
ultimanet committed
482
    def __getitem__(self, key):
        """
            Returns self[key].

            A boolean numpy.ndarray or boolean distributed_data_object key
            acts as a mask: each node applies its local part of the mask, and
            the selected values of all nodes are concatenated into one flat
            numpy.ndarray. Any other key is handled by get_data.
        """
        ## TODO: transfer this into distributor:
        if isinstance(key, np.ndarray):
            is_boolean_mask = (key.dtype.type == np.bool_)
        elif isinstance(key, distributed_data_object):
            is_boolean_mask = (key.dtype == np.bool_)
        else:
            is_boolean_mask = False

        if not is_boolean_mask:
            return self.get_data(key)

        ## extract the data of local relevance
        local_mask = self.distributor.extract_local_data(key)
        local_hits = self.get_local_data(copy=False)[local_mask]
        gathered_hits = self.distributor._allgather(local_hits)
        return np.concatenate(gathered_hits)
ultimanet's avatar
ultimanet committed
505
506
507
508
    
    def __setitem__(self, key, data):
        """self[key] = data, delegated to set_data."""
        self.set_data(data=data, key=key)
        
509
    def _contraction_helper(self, function, **kwargs):
        """
            Applies `function` (e.g. np.sum) to the node-local data, gathers
            the partial results of all nodes and reduces them with `function`
            along axis 0.
        """
        partial_result = function(self.data, **kwargs)
        all_partials = self.distributor._allgather(partial_result)
        return function(all_partials, axis=0)
        
    def amin(self, **kwargs):
        """Global minimum (np.amin over all nodes)."""
        reducer = np.amin
        return self._contraction_helper(reducer, **kwargs)
517
518

    def nanmin(self, **kwargs):
        """Global minimum ignoring NaNs (np.nanmin over all nodes)."""
        reducer = np.nanmin
        return self._contraction_helper(reducer, **kwargs)
520
521
        
    def amax(self, **kwargs):
        """Global maximum (np.amax over all nodes)."""
        reducer = np.amax
        return self._contraction_helper(reducer, **kwargs)
523
524
    
    def nanmax(self, **kwargs):
        """Global maximum ignoring NaNs (np.nanmax over all nodes)."""
        reducer = np.nanmax
        return self._contraction_helper(reducer, **kwargs)
Ultimanet's avatar
Ultimanet committed
526
    
527
528
529
530
531
532
    def sum(self, **kwargs):
        """Global sum (np.sum over all nodes)."""
        reducer = np.sum
        return self._contraction_helper(reducer, **kwargs)

    def prod(self, **kwargs):
        """Global product (np.prod over all nodes)."""
        reducer = np.prod
        return self._contraction_helper(reducer, **kwargs)
        
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
    def mean(self, power=1):
        """
            Returns the global mean of self**power.

            Each node contributes its local mean weighted by its number of
            local elements. A node holding no data contributes a zero mean
            with zero weight; np.mean of an empty array is NaN, and NaN*0
            would poison the global result.
        """
        ## compute the local means and the weights for the mean-mean.
        local_data = self.data**power
        local_weight = np.prod(self.data.shape)
        if local_weight > 0:
            local_mean = np.mean(local_data)
        else:
            local_mean = 0.
        ## collect the local means and cast the result to a ndarray
        local_mean_weight_list = self.distributor._allgather((local_mean,
                                                              local_weight))
        local_mean_weight_list = np.array(local_mean_weight_list)
        ## compute the denominator for the weighted mean-mean
        global_weight = np.sum(local_mean_weight_list[:, 1])
        ## compute the numerator
        numerator = np.sum(local_mean_weight_list[:, 0] *
                           local_mean_weight_list[:, 1])
        global_mean = numerator/global_weight
        return global_mean

    def var(self):
        """
            Returns the global variance mean(|x|**2) - |mean(x)|**2.

            For real data this equals mean(x**2) - mean(x)**2 (unchanged).
            For complex data absolute values are used, so the result is real,
            as with numpy.var; previously mean(x**2) - mean(x)**2 was
            returned, which is complex.
        """
        if np.issubdtype(self.dtype, np.complexfloating):
            mean_of_the_square = abs(self).mean(power=2)
            square_of_the_mean = abs(self.mean())**2
        else:
            mean_of_the_square = self.mean(power=2)
            square_of_the_mean = self.mean()**2
        return mean_of_the_square - square_of_the_mean
    
    def std(self):
        """Global standard deviation, the square root of var()."""
        variance = self.var()
        return np.sqrt(variance)
        
557
558
559
560
561
562
563
564
565
566
567
#    def _argmin_argmax_flat_helper(self, function):
#        local_argmin = function(self.data)
#        local_argmin_value = self.data[np.unravel_index(local_argmin, 
#                                                        self.data.shape)]
#        globalized_local_argmin = self.distributor.globalize_flat_index(local_argmin)                                                       
#        local_argmin_list = self.distributor._allgather((local_argmin_value, 
#                                                         globalized_local_argmin))
#        local_argmin_list = np.array(local_argmin_list, dtype=[('value', int),
#                                                               ('index', int)])    
#        return local_argmin_list
#        
568
569
570
571
    def argmin_flat(self):
        """
            Returns the flat global index of the minimum.

            Each node reports its local minimum together with its globalized
            flat index. After sorting by (value, index) the first entry wins,
            so ties go to the smaller global index.
        """
        local_index = np.argmin(self.data)
        local_value = self.data[np.unravel_index(local_index,
                                                 self.data.shape)]
        global_index = self.distributor.globalize_flat_index(local_index)
        candidates = self.distributor._allgather((local_value, global_index))
        record_type = [('value', local_value.dtype), ('index', int)]
        candidates = np.sort(np.array(candidates, dtype=record_type),
                             order=['value', 'index'])
        return candidates[0][1]
    
    def argmax_flat(self):
        """
            Returns the flat global index of the maximum.

            Each node reports its local maximum together with its globalized
            flat index; ties go to the smaller global index. The candidates
            are sorted by (value, -index) and the last entry is taken.
            Previously the values themselves were negated, which wraps around
            for unsigned integers and fails for booleans.
        """
        local_argmax = np.argmax(self.data)
        local_argmax_value = self.data[np.unravel_index(local_argmax,
                                                        self.data.shape)]
        globalized_local_argmax = self.distributor.globalize_flat_index(
                                                                local_argmax)
        local_argmax_list = self.distributor._allgather((local_argmax_value,
                                                    globalized_local_argmax))
        local_argmax_list = np.array(local_argmax_list, dtype=[
                                        ('value', local_argmax_value.dtype),
                                        ('neg_index', int)])
        ## negate the (integer) indices, never the values: ascending sort by
        ## (value, -index) puts the largest value with the smallest index last
        local_argmax_list['neg_index'] = -local_argmax_list['neg_index']
        local_argmax_list = np.sort(local_argmax_list,
                                    order=['value', 'neg_index'])
        return -local_argmax_list[-1]['neg_index']
        

    def argmin(self):
        """Multi-dimensional global index of the minimum."""
        flat_index = self.argmin_flat()
        return np.unravel_index(flat_index, self.shape)
    
    def argmax(self):
        """Multi-dimensional global index of the maximum."""
        flat_index = self.argmax_flat()
        return np.unravel_index(flat_index, self.shape)
    
    def conjugate(self):
        """Returns the element-wise complex conjugate as a new object."""
        conjugated_d2o = self.copy_empty()
        conjugated_d2o.set_local_data(np.conj(self.get_local_data()))
        return conjugated_d2o

    
    def conj(self):
        """Alias of conjugate()."""
        conjugated = self.conjugate()
        return conjugated
        
    def median(self):
        """
            Global median. Gathers the full data on every node first, which
            is expensive; a warning is printed accordingly.
        """
        about.warnings.cprint(
            "WARNING: The current implementation of median is very expensive!")
        full_data = self.get_full_data()
        return np.median(full_data)
        
621
    def iscomplex(self):
        """Boolean object marking elements with non-zero imaginary part."""
        mask = self.copy_empty(dtype=np.bool_)
        local_flags = np.iscomplex(self.data)
        mask.set_local_data(local_flags)
        return mask
    
    def isreal(self):
        """Boolean object marking elements with zero imaginary part."""
        mask = self.copy_empty(dtype=np.bool_)
        local_flags = np.isreal(self.data)
        mask.set_local_data(local_flags)
        return mask
    
631

632
633
634
635
636
637
638
639
    def all(self):
        """True if all elements on all nodes are truthy."""
        per_node = self.distributor._allgather(np.all(self.get_local_data()))
        return np.all(per_node)

    def any(self):
        """True if any element on any node is truthy."""
        per_node = self.distributor._allgather(np.any(self.get_local_data()))
        return np.any(per_node)
641
        
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
    def unique(self):
        """Sorted unique values across all nodes (a numpy.ndarray)."""
        per_node = self.distributor._allgather(
                                        np.unique(self.get_local_data()))
        return np.unique(np.concatenate(per_node))
        
    def bincount(self, weights=None, minlength=None):
        """
            Global np.bincount of the integer data.

            Parameters
            ----------
            weights : array-like or distributed_data_object, *optional*
                Weights matching the global shape of self.
            minlength : int, *optional*
                Minimal length of the result. The result is always at least
                max(self)+1 long, so that the per-node counts have equal
                length and can be summed.

            Returns
            -------
            numpy.ndarray
                The counts summed over all nodes.

            Raises
            ------
            TypeError
                If the data is not of an integer dtype.
        """
        if not np.issubdtype(self.dtype, np.integer):
            raise TypeError(about._errors.cstring(
                "ERROR: Distributed-data-object must be of integer datatype!"))

        ## max(x, None) relied on Python 2 ordering; handle None explicitly
        global_minlength = self.amax() + 1
        if minlength is not None:
            global_minlength = max(global_minlength, minlength)

        if weights is not None:
            local_weights = self.distributor.extract_local_data(weights).\
                                                                    flatten()
        else:
            local_weights = None

        local_counts = np.bincount(self.get_local_data().flatten(),
                                   weights=local_weights,
                                   minlength=global_minlength)
        list_of_counts = self.distributor._allgather(local_counts)
        return np.sum(list_of_counts, axis=0)
                              
670
    
671
    def set_local_data(self, data, hermitian=False, copy=True):
        """
            Stores `data` directly as this node's local data, without any
            distribution. The data must already have the local shape.

            Parameters
            ----------
            data : tuple, list, numpy.ndarray
                The new local data; converted to a C-ordered array of
                self.dtype.
            hermitian : bool
                Stored in self.hermitian.
            copy : bool
                Passed to np.array.

            Returns
            -------
            None
        """
        ## hermitian is assigned before the conversion, exactly as before,
        ## so it is set even if the conversion raises.
        self.hermitian = hermitian
        local_array = np.array(data, dtype=self.dtype, copy=copy, order='C')
        self.data = local_array
ultimanet's avatar
ultimanet committed
689
    
690
    def set_data(self, data, key, hermitian=False, copy=True, *args, **kwargs):
        """
            Stores `data` in the global region selected by `key`.

            The key is turned into a tuple of slices, the data is reshaped to
            fit it, and the distributor's disperse_data writes it into the
            local data. If the nodes pass different keys, the distributor
            processes them one node after the other.

            Parameters
            ----------
            data : tuple, list, numpy.ndarray
                The values to be stored.
            key : int, slice, tuple of int or slice
                The target region.
            hermitian : bool
                Stored in self.hermitian.
            copy : bool
                Forwarded to disperse_data.

            Returns
            -------
            None
        """
        self.hermitian = hermitian
        slices, sliceified = self.__sliceify__(key)
        enfolded_update = self.__enfold__(data, sliceified)
        self.distributor.disperse_data(data=self.data,
                                       to_slices=slices,
                                       data_update=enfolded_update,
                                       copy=copy,
                                       *args, **kwargs)
ultimanet's avatar
ultimanet committed
717
    
718
    def set_full_data(self, data, hermitian=False, copy = True, **kwargs):
        """
            Distributes the full global `data` to the nodes; its shape must
            match the object's global shape.

            Parameters
            ----------
            data : tuple, list, numpy.ndarray
                The complete data to be distributed.
            hermitian : bool
                Stored in self.hermitian.
            copy : bool
                Forwarded to the distributor's distribute_data.
            **kwargs
                Forwarded to the distributor's distribute_data.

            Notes
            -----
            Equivalent to set_data(data, slice(None)), but faster.

            Returns
            -------
            None
        """
        self.hermitian = hermitian
        distributed = self.distributor.distribute_data(data=data, copy=copy,
                                                       **kwargs)
        self.data = distributed
ultimanet's avatar
ultimanet committed
741
742
    

Ultimanet's avatar
Ultimanet committed
743
    def get_local_data(self, key=(slice(None),), copy=True):
        """
            Reads directly from this node's local data; nothing is
            consolidated across nodes.

            Parameters
            ----------
            key : int, slice, tuple of int or slice
                Index applied to the local data. Only used when copy is True.
            copy : bool
                True: return self.data[key]. False: return the local data
                array itself; `key` is ignored.

            Returns
            -------
            numpy.ndarray (None if copy is neither == True nor == False)
        """
        ## NOTE(review): for slice keys self.data[key] is a numpy view, so
        ## even with copy=True the result shares memory with self.data.
        if copy == False:
            return self.data
        if copy == True:
            return self.data[key]
        return None
ultimanet's avatar
ultimanet committed
762
        
763
    def get_data(self, key, **kwargs):
        """
            Loads the global region selected by `key`, consolidated by the
            distributor. If the nodes pass different keys, each node gets
            the data for its own key.

            Parameters
            ----------
            key : int, slice, tuple of int or slice
                The region to load.
            **kwargs
                Forwarded to the distributor's collect_data.

            Returns
            -------
            global_data[key] : numpy.ndarray
        """
        slices, sliceified = self.__sliceify__(key)
        gathered = self.distributor.collect_data(self.data, slices, **kwargs)
        ## drop the length-1 axes that integer keys were expanded to
        return self.__defold__(gathered, sliceified)
        
    
    
    def get_full_data(self, target_rank='all'):
        """
            Fully consolidates the distributed data.

            Parameters
            ----------
            target_rank : 'all' (default), int *optional*
                If only one node should receive the full data, it can be
                specified here.

            Notes
            -----
            get_full_data() is equivalent to get_data(slice(None)) but
            faster.

            Returns
            -------
            The result of the distributor's consolidate_data for the local
            data and `target_rank`, i.e. the full global data (on the node(s)
            selected by target_rank).
        """
        return self.distributor.consolidate_data(self.data,
                                                 target_rank = target_rank)
ultimanet's avatar
ultimanet committed
810

Ultimanet's avatar
Ultimanet committed
811
812
813
814
815
816
817
    def inject(self, to_slices=(slice(None),), data=None,
               from_slices=(slice(None),)):
        """
            Copies data[from_slices] into the region to_slices of this
            object via the distributor's inject.

            Parameters
            ----------
            to_slices : tuple of slice
                Target region in this object.
            data : object or None
                Source of the values; if None, nothing happens.
            from_slices : tuple of slice
                Source region in `data`.

            Returns
            -------
            self
        """
        ## `is None` instead of `== None`: for an ndarray `data`, `== None`
        ## is element-wise and its truth value is ambiguous.
        if data is None:
            return self

        self.distributor.inject(self.data, to_slices, data, from_slices)
        return self
        
818
819
820
821
822
823
824
825
826
827
828
    def flatten(self, inplace = False):
        """
            Returns a one-dimensional version of this object.

            A new object of global shape (prod(self.shape),) is created with
            copy_empty and filled with the distributor's flattened local
            data. `inplace` is forwarded to the distributor's flatten.

            Returns
            -------
            The newly created flat object (for either value of inplace).
        """
        flat_d2o = self.copy_empty(global_shape=(np.prod(self.shape),))
        flat_d2o.set_local_data(
            data=self.distributor.flatten(self.data, inplace=inplace))
        ## NOTE(review): with inplace=True the caller's object is not
        ## rebound to the flat one (the original `self = ...` only rebound a
        ## local name); callers must use the return value.
        return flat_d2o
        
Ultimanet's avatar
Ultimanet committed
829
        
830

ultimanet's avatar
ultimanet committed
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
      
    def save(self, alias, path=None, overwriteQ=True):
        """
            Saves the object to disk by handing the local data to the
            distributor's save_data.

            Parameters
            ----------
            alias : string
                Name of the dataset inside the hdf5 file.
            path : string *optional*
                Path to the hdf5 file. If no path is given, the alias is
                taken as filename in the current path.
            overwriteQ : Boolean *optional*
                Whether an existing dataset of the same name may be
                overwritten.
        """
        save_data = self.distributor.save_data
        save_data(self.data, alias, path, overwriteQ)

    def load(self, alias, path=None):
        """
            Loads the local data from disk via the distributor's load_data.

            Parameters
            ----------
            alias : string
                Name of the dataset inside the hdf5 file.
            path : string *optional*
                Path to the hdf5 file. If no path is given, the alias is
                taken as filename in the current path.
        """
        loaded = self.distributor.load_data(alias, path)
        self.data = loaded
           
    def __sliceify__(self, inp):
        """
            Converts an indexing key into a tuple of slice objects.

            Parameters
            ----------
            inp : int, slice, tuple or list of int/slice
                The key.

            Returns
            -------
            (slices, sliceified) : (tuple of slice, list of bool)
                `sliceified[i]` is True where entry i was not a slice and was
                turned into the length-1 slice covering that index.
        """
        if isinstance(inp, tuple):
            keys = inp
        elif isinstance(inp, list):
            keys = tuple(inp)
        else:
            keys = (inp, )

        result = []
        sliceified = []
        for key in keys:
            if isinstance(key, slice):
                result.append(key)
                sliceified.append(False)
            else:
                ## slice(k, k+1) selects index k -- except for k == -1,
                ## where it would be the empty slice(-1, 0).
                if isinstance(key, (int, np.integer)) and key == -1:
                    stop = None
                else:
                    stop = key + 1
                result.append(slice(key, stop))
                sliceified.append(True)

        return (tuple(result), sliceified)
                
                
    def __enfold__(self, in_data, sliceified):
        """
            Reshapes `in_data` so that it fits a key produced by
            __sliceify__.

            A length-1 axis is put at every position whose key entry was an
            integer (sliceified True); an existing length-1 axis of the data
            at that position is consumed. Slice positions take the next axis
            of the data (or length 1 if the data has no more axes). Remaining
            data axes are appended, because the key may be shorter than the
            number of dimensions.
        """
        array = np.array(in_data, copy=False)
        dims = array.shape
        n_dims = len(dims)
        target = []
        axis = 0
        for is_integer_key in sliceified:
            if is_integer_key == True:
                target.append(1)
                if axis < n_dims and dims[axis] == 1:
                    axis += 1
            else:
                target.append(dims[axis] if axis < n_dims else 1)
                axis += 1
        target.extend(dims[axis:])
        return array.reshape(tuple(target))
    
    def __defold__(self, data, sliceified):
        """
            Inverse of __enfold__ for collected data: indexes the axes that
            came from integer keys (sliceified True) with 0 and keeps all
            other axes.
        """
        index = tuple(0 if flag == True else slice(None)
                      for flag in sliceified)
        return data[index]

    
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
class _distributor_factory(object):
    '''
        Builds distributors and caches them by their construction arguments.

        Comments:
          - The distributor's get_data and set_data functions MUST be 
            supplied with a tuple of slice objects. In case that there was 
            a direct integer involved, the unfolding will be done by the
            helper functions __sliceify__, __enfold__ and __defold__.
    '''
    def __init__(self):
        ## maps the output of hash_arguments to produced distributors
        self.distributor_store = {}

    def parse_kwargs(self, strategy = None, kwargs = None):
        """
            Filters the keyword arguments relevant for `strategy`.

            Only 'fftw' and 'equal' take a 'comm' argument; every other key
            (and every key for 'not') is dropped.
        """
        if kwargs is None:
            kwargs = {}
        return_dict = {}
        if strategy in ('fftw', 'equal') and 'comm' in kwargs:
            return_dict['comm'] = kwargs['comm']
        return return_dict

    def hash_arguments(self, global_shape, dtype, distribution_strategy,
                       kwargs=None):
        """
            Builds a hashable cache key from the distributor arguments.

            The communicator is represented by its id(); a list or ndarray
            global_shape is converted to a tuple so that it is hashable.
        """
        kwargs = {} if kwargs is None else kwargs.copy()
        if 'comm' in kwargs:
            kwargs['comm'] = id(kwargs['comm'])
        if isinstance(global_shape, (list, np.ndarray)):
            global_shape = tuple(global_shape)
        kwargs['global_shape'] = global_shape
        kwargs['dtype'] = self.dictionize_np(dtype)
        kwargs['distribution_strategy'] = distribution_strategy
        return frozenset(kwargs.items())

    def dictionize_np(self, x):
        """
            Returns a hashable stand-in for the datatype `x`.

            Types and numpy dtypes are hashable themselves and are returned
            unchanged; anything unhashable is represented by its id().
        """
        try:
            hash(x)
        except TypeError:
            return id(x)
        return x

    def get_distributor(self, distribution_strategy, global_shape, dtype,
                        **kwargs):
        """
            Returns a distributor for the given arguments, reusing a cached
            one if the same arguments were seen before.

            Raises
            ------
            TypeError
                If distribution_strategy is not available: 'not' and 'equal'
                always are, 'fftw' only if pyfftw and mpi4py were found.
        """
        ## check if the distribution strategy is known (and available)
        known_distribution_strategies = ['not', 'equal']
        if found['pyfftw'] == True and found['MPI'] == True:
            known_distribution_strategies += ['fftw', ]
        if distribution_strategy not in known_distribution_strategies:
            raise TypeError(about._errors.cstring(
                "ERROR: Unknown distribution strategy supplied."))

        ## parse the kwargs
        parsed_kwargs = self.parse_kwargs(strategy = distribution_strategy,
                                          kwargs = kwargs)
        hashed_arguments = self.hash_arguments(
                                global_shape = global_shape,
                                dtype = dtype,
                                distribution_strategy = distribution_strategy,
                                kwargs = parsed_kwargs)
        ## reuse the distributor if it has already been produced in the past
        if hashed_arguments in self.distributor_store:
            return self.distributor_store[hashed_arguments]

        ## produce new distributor
        if distribution_strategy == 'not':
            produced_distributor = _not_distributor(
                                                global_shape = global_shape,
                                                dtype = dtype)
        elif distribution_strategy == 'equal':
            produced_distributor = _slicing_distributor(
                                                slicer = _equal_slicer,
                                                global_shape = global_shape,
                                                dtype = dtype,
                                                **parsed_kwargs)
        else:
            ## 'fftw' -- only reachable if it passed the check above
            produced_distributor = _slicing_distributor(
                                                slicer = _fftw_slicer,
                                                global_shape = global_shape,
                                                dtype = dtype,
                                                **parsed_kwargs)
        self.distributor_store[hashed_arguments] = produced_distributor
        return produced_distributor


distributor_factory = _distributor_factory()
ultimanet's avatar
ultimanet committed
1007
        
1008
1009
class _slicing_distributor(object):
    
ultimanet's avatar
ultimanet committed
1010

1011
1012
    def __init__(self, slicer, global_shape=None, dtype=None, 
                 comm=MPI.COMM_WORLD):
        """
            Sets up the slicing of a distributed array along its first axis.

            Parameters
            ----------
            slicer : callable
                Called as slicer(global_shape, comm=comm); entries [0] and
                [1] of its result are this node's start and end index along
                the first axis.
            global_shape : tuple of int
                Shape of the full array. Only node 0's value is used; it is
                broadcast to all nodes.
            dtype : type
                Datatype of the array. Only node 0's value is used; it is
                broadcast to all nodes.
            comm : MPI communicator

            Raises
            ------
            TypeError
                On every node, if node 0 supplied no shape or no dtype, or
                if the dtype is not known to the dtype converter.
        """
        ## Broadcast node 0's values first and validate afterwards: if node 0
        ## raised before the bcast, the other nodes would block in it forever.
        self.global_shape = comm.bcast(
            global_shape if comm.rank == 0 else None, root = 0)
        if self.global_shape is None:
            raise TypeError(about._errors.cstring(
                "ERROR: No shape supplied!"))
        self.global_shape = tuple(self.global_shape)

        self.dtype = comm.bcast(dtype if comm.rank == 0 else None, root=0)
        if self.dtype is None:
            raise TypeError(about._errors.cstring(
                "ERROR: Failed setting datatype! No datatype supplied."))

        self._my_dtype_converter = _global_dtype_converter

        if not self._my_dtype_converter.known_np_Q(self.dtype):
            raise TypeError(about._errors.cstring(
            "ERROR: The datatype "+str(self.dtype)+" is not known to mpi4py."))

        self.mpi_dtype = self._my_dtype_converter.to_mpi(self.dtype)

        self.slicer = lambda global_shape: slicer(global_shape, comm = comm)
        self._local_size = self.slicer(self.global_shape)
        self.local_start = self._local_size[0]
        self.local_end = self._local_size[1]

        self.local_length = self.local_end-self.local_start
        self.local_shape = (self.local_length,) + tuple(self.global_shape[1:])
        self.local_dim = np.prod(self.local_shape)
        ## dtype=int is what the former np.int alias stood for
        self.local_dim_list = np.empty(comm.size, dtype=int)
        comm.Allgather([np.array(self.local_dim, dtype=int), MPI.INT],
                       [self.local_dim_list, MPI.INT])
        self.local_dim_offset = np.sum(self.local_dim_list[0:comm.rank])

        self.local_slice = np.array([self.local_start, self.local_end,
                                     self.local_length, self.local_dim,
                                     self.local_dim_offset],
                                    dtype=int)
        ## collect all local_slices 
        ## [start, stop, length=stop-start, dimension, dimension_offset]
        self.all_local_slices = np.empty((comm.size, 5), dtype=int)
        comm.Allgather([np.array((self.local_slice,), dtype=int), MPI.INT],
                       [self.all_local_slices, MPI.INT])

        self.comm = comm
ultimanet's avatar
ultimanet committed
1071
        
1072
1073
1074
1075
1076
1077
    def globalize_flat_index(self, index):
        """Converts a node-local flat index into a global flat index."""
        local_flat_index = int(index)
        return local_flat_index + self.local_dim_offset
        
    def globalize_index(self, index):
        """
            Converts a node-local multi-index into a global multi-index.

            The first entry is shifted by self.local_start; the result is
            clipped into [-global_shape, global_shape-1] per axis, with a
            warning if clipping changed anything.

            Parameters
            ----------
            index : sequence of int
                One entry per axis of the global shape.

            Returns
            -------
            tuple of int

            Raises
            ------
            TypeError
                If the number of entries does not match the number of axes.
        """
        local_index = np.array(index, dtype=int).flatten()
        if local_index.shape != (len(self.global_shape),):
            raise TypeError(about._errors.cstring(
                "ERROR: Length of index tuple does not match the array's "
                "shape!"))
        shifted_index = local_index.copy()
        shifted_index[0] += self.local_start
        ## ensure that the globalized index list is within the bounds
        global_shape = np.array(self.global_shape)
        clipped_index = np.clip(shifted_index, -global_shape, global_shape-1)
        if np.any(clipped_index != shifted_index):
            about.warnings.cprint("WARNING: Indices were clipped!")
        return tuple(clipped_index)
    
    def _allgather(self, thing, comm=None):
        """
            Gathers `thing` from every node via comm.allgather (self.comm if
            no communicator is given) and returns the list of results.
        """
        communicator = self.comm if comm is None else comm
        return communicator.allgather(thing)
    
    def distribute_data(self, data=None, comm = None, alias=None,
                        path=None, copy=True, **kwargs):
        '''
            Returns this node's share of the global data.

            The source is chosen as follows:
            - if node 0 got no data, h5py is available and an alias or path
              is given, the dataset `alias` is read from the hdf5 file at
              `path` (or at `alias` if no path is given);
            - if no node got data, an uninitialized local array is returned;
            - if all nodes got data, it is assumed to be the full data and
              every node cuts out its own part;
            - if node 0 got no data, an uninitialized local array is
              returned;
            - otherwise node 0's data is scattered to all nodes.

            Parameters
            ----------
            data : numpy.ndarray or None
            comm : MPI communicator *optional* (defaults to self.comm)
            alias : string *optional*
            path : string *optional*
            copy : bool
                Used when all nodes got data.

            Raises
            ------
            TypeError
                If the hdf5 dataset has the wrong shape or dtype, or if
                node 0's data has the wrong shape (raised on all nodes).
        '''
        if comm is None:
            comm = self.comm
        rank = comm.Get_rank()
        size = comm.Get_size()
        local_data_available_Q = np.array((int(data is not None), ))
        data_available_Q = np.empty(size, dtype=int)
        comm.Allgather([local_data_available_Q, MPI.INT],
                       [data_available_Q, MPI.INT])

        if data_available_Q[0] == False and found['h5py'] and \
                (path is not None or alias is not None):
            try:
                file_path = path if path is not None else alias
                if found['h5py_parallel']:
                    f = h5py.File(file_path, 'r', driver='mpio', comm=comm)
                else:
                    f = h5py.File(file_path, 'r')
                ## close the file on every exit path, including the raise
                try:
                    dset = f[alias]
                    if dset.shape == self.global_shape and \
                            dset.dtype.type == self.dtype:
                        return dset[self.local_start:self.local_end]
                    raise TypeError(about._errors.cstring(
                        "ERROR: Input data has the wrong shape or wrong "
                        "dtype!"))
                finally:
                    f.close()
            except(IOError, AttributeError):
                pass

        if np.all(data_available_Q == False):
            return np.empty(self.local_shape, dtype=self.dtype, order='C')
        ## if all nodes got data, we assume that it is the right data and 
        ## store it individually. If not, take the data on node 0 and scatter 
        ## it...
        if np.all(data_available_Q):
            return data[self.local_start:self.local_end].astype(self.dtype,
                                                                copy=copy)
        ## ... but only if node 0 has actually data!
        if data_available_Q[0] == False:
            return np.empty(self.local_shape, dtype=self.dtype, order='C')

        ## Scatter node 0's data. Nodes without data only need a placeholder
        ## send buffer; MPI uses the send buffer on the root only.
        if data is None:
            data = np.empty(self.global_shape, dtype=self.dtype)
        if rank == 0:
            ## Scatterv sends raw memory, so the root buffer must be
            ## C-contiguous and already of the target dtype.
            data = np.ascontiguousarray(data, dtype=self.dtype)
            shape_ok = (data.shape == self.global_shape)
        else:
            shape_ok = None
        ## share the outcome of the shape check so that all nodes raise
        ## together instead of the others hanging in Scatterv
        shape_ok = comm.bcast(shape_ok, root=0)
        if not shape_ok:
            raise TypeError(about._errors.cstring(
                "ERROR: Input data has the wrong shape!"))

        _scattered_data = np.empty(self.local_shape, dtype = self.dtype)
        _dim_list = self.all_local_slices[:, 3]
        _dim_offset_list = self.all_local_slices[:, 4]
        comm.Scatterv([data, _dim_list, _dim_offset_list, self.mpi_dtype],
                      [_scattered_data, self.mpi_dtype], root=0)
        return _scattered_data
    
1161
1162
1163
1164
1165
1166
    

    
    
    def disperse_data(self, data, to_slices, data_update, from_slices=None,
                      comm=None, copy = True, **kwargs):
        """
            Writes data_update into the region to_slices of the distributed
            data.

            If all nodes pass the same to_slices, one call with
            source_rank='all' does the job: every node takes its share from
            its own data_update. Otherwise the region of every node is
            processed in turn, with that node as the source.

            Returns
            -------
            None
        """
        if comm is None:
            comm = self.comm
        gathered_to_slices = comm.allgather(to_slices)
        reference = gathered_to_slices[0]

        if all(candidate == reference for candidate in gathered_to_slices):
            self._disperse_data_primitive(data=data,
                                          to_slices=to_slices,
                                          data_update=data_update,
                                          from_slices=from_slices,
                                          source_rank='all',
                                          comm=comm,
                                          copy=copy)
            return None

        ## different slices on different nodes: one collective round each
        for source_rank, rank_to_slices in enumerate(gathered_to_slices):
            self._disperse_data_primitive(data=data,
                                          to_slices=rank_to_slices,
                                          data_update=data_update,
                                          from_slices=from_slices,
                                          source_rank=source_rank,
                                          comm=comm,
                                          copy=copy)
        return None
                 
        
#    def _disperse_data_primitive(self, data, to_slices, data_update, 
#                        from_slices, source_rank='all', comm=None, copy=True):
#        ## compute the part of the to_slice which is relevant for the 
#        ## individual node      
#        localized_to_start, localized_to_stop = self._backshift_and_decycle(
#            to_slices[0], self.local_start, self.local_end,\
#                self.global_shape[0])
#        local_to_slice = (slice(localized_to_start, localized_to_stop,\
#                        to_slices[0].step),) + to_slices[1:]
#                        
#        ## compute the parameter sets and list for the data splitting
#        local_slice_shape = data[local_slice].shape        
#        local_affected_data_length = local_slice_shape[0]
#        local_affected_data_length_list=np.empty(comm.size, dtype=np.int)        
#        comm.Allgather(\
#            [np.array(local_affected_data_length, dtype=np.int), MPI.INT],\
#            [local_affected_data_length_list, MPI.INT])        
#        local_affected_data_length_offset_list = np.append([0],\
#                            np.cumsum(local_affected_data_length_list)[:-1])
#
#    
    
    def _disperse_data_primitive(self, data, to_slices, data_update, 
                        from_slices, source_rank='all', comm=None, copy=True):
        """
            Writes (part of) data_update into the part of the local data
            selected by to_slices.

            Parameters
            ----------
            data : numpy.ndarray
                This node's local data; modified in place.
            to_slices : tuple of slice
                Global target region. Its first entry is localized onto this
                node's range of the first axis via _backshift_and_decycle.
            data_update : numpy.ndarray
                Source of the new values.
            from_slices : tuple of slice or None
                Region of data_update to take the values from; None selects
                everything.
            source_rank : 'all' or int
                'all': every node takes its share from its own data_update.
                int: that node's data_update[from_slices] is scattered to all
                nodes with Scatterv.
            comm : MPI communicator *optional*
                Defaults to self.comm.
            copy : bool
                Passed to np.array in the source_rank == 'all' case.

            Returns
            -------
            None
        """
        if comm is None:
            comm = self.comm

        ## localize the first to_slice onto this node's part of the first axis
        localized_to_start, localized_to_stop = self._backshift_and_decycle(
            to_slices[0], self.local_start, self.local_end,
            self.global_shape[0])
        local_to_slice = (slice(localized_to_start, localized_to_stop,
                                to_slices[0].step),) + to_slices[1:]
        local_to_slice_shape = data[local_to_slice].shape

        ## Number of first-axis entries every node receives. Both branches
        ## need it, so it is gathered before branching (the Scatterv branch
        ## used to reference it without ever computing it).
        local_affected_data_length = local_to_slice_shape[0]
        local_affected_data_length_list = np.empty(comm.size, dtype=int)
        comm.Allgather(
            [np.array(local_affected_data_length, dtype=int), MPI.INT],
            [local_affected_data_length_list, MPI.INT])

        if source_rank == 'all':
            ## parse the from_slices object
            if from_slices is None:
                from_slices = (slice(None, None, None),)
            (from_slices_start, from_slices_stop) = \
                self._backshift_and_decycle(
                    slice_object = from_slices[0],
                    shifted_start = 0,
                    shifted_stop = data_update.shape[0],
                    global_length = data_update.shape[0])
            if from_slices_start is None:
                raise ValueError(about._errors.cstring(
                        "ERROR: _backshift_and_decycle should never return "+
                        "None for local_start!"))

            ## parse the step sizes
            from_step = from_slices[0].step
            if from_step is None:
                from_step = 1
            elif from_step == 0:
                raise ValueError(about._errors.cstring(
                    "ERROR: from_step size == 0!"))

            to_step = to_slices[0].step
            if to_step is None:
                to_step = 1
            elif to_step == 0:
                raise ValueError(about._errors.cstring(
                    "ERROR: to_step size == 0!"))

            ## Offset of this node's share in terms of the purely transported
            ## data (free of step sizes). If to_step < 0, the offsets are
            ## counted in reverse node order.
            order = np.sign(to_step)
            local_affected_data_length_offset_list = np.append(
                [0],
                np.cumsum(local_affected_data_length_list[::order])[:-1]
            )[::order]

            ## construct the locally adapted from_slice object
            localized_from_start = from_slices_start + \
                from_step * local_affected_data_length_offset_list[comm.rank]
            localized_from_stop = localized_from_start + \
                from_step * local_affected_data_length
            if localized_from_stop < 0:
                localized_from_stop = None
            localized_from_slice = (slice(localized_from_start,
                                          localized_from_stop,
                                          from_step),)
            update_slice = localized_from_slice + from_slices[1:]
            data[local_to_slice] = np.array(data_update[update_slice],
                                            copy=copy).astype(self.dtype)
        else:
            ## Scatterv the relevant part from source_rank to the others and
            ## plug it into data[local_to_slice]. For a negative first-axis
            ## step the ordering of the Scatterv offsets is reversed.
            order = to_slices[0].step
            if order is None:
                order = 1
            else:
                order = np.sign(order)
            if from_slices is None:
                from_slices = (slice(None, None, None),)

            ## int(): np.prod(()) is the float 1.0 for one-dimensional data,
            ## but Scatterv needs integer counts and displacements
            local_affected_data_dim_list = \
                np.array(local_affected_data_length_list) * \
                int(np.prod(local_to_slice_shape[1:]))
            local_affected_data_dim_offset_list = np.append(
                [0],
                np.cumsum(local_affected_data_dim_list[::order])[:-1]
            )[::order]

            local_dispersed_data = np.zeros(local_to_slice_shape,
                                            dtype=self.dtype)
            ## Scatterv sends raw memory: the send buffer must be C-contiguous
            ## and of the target dtype
            send_buffer = np.ascontiguousarray(data_update[from_slices],
                                               dtype=self.dtype)
            comm.Scatterv([send_buffer,
                           local_affected_data_dim_list,
                           local_affected_data_dim_offset_list,
                           self.mpi_dtype],
                          [local_dispersed_data, self.mpi_dtype],
                          root=source_rank)
            data[local_to_slice] = local_dispersed_data
        return None
        
1336
1337

    def collect_data(self, data, slice_objects, comm=None, **kwargs):
1338
        if comm == None:
1339
1340
            comm = self.comm                    
        slice_objects_list = comm.allgather(slice_objects)
ultimanet's avatar
ultimanet committed
1341
        ## check if all slices are the same. 
1342
1343
        if all