nifty_mpi_data.py 95 KB
Newer Older
ultimanet's avatar
ultimanet committed
1
# -*- coding: utf-8 -*-
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
## NIFTY (Numerical Information Field Theory) has been developed at the
## Max-Planck-Institute for Astrophysics.
##
## Copyright (C) 2015 Max-Planck-Society
##
## Author: Theo Steininger
## Project homepage: <http://www.mpa-garching.mpg.de/ift/nifty/>
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
## See the GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <http://www.gnu.org/licenses/>.


ultimanet's avatar
ultimanet committed
24
25


26

Ultima's avatar
Ultima committed
27
28
##initialize the 'FOUND-packages'-dictionary 
FOUND = {}
ultimanet's avatar
ultimanet committed
29
import numpy as np
Ultimanet's avatar
Ultimanet committed
30
from nifty_about import about
ultimanet's avatar
ultimanet committed
31
32

try:
33
    from mpi4py import MPI
Ultima's avatar
Ultima committed
34
    FOUND['MPI'] = True
ultimanet's avatar
ultimanet committed
35
except(ImportError): 
36
    import mpi_dummy as MPI
Ultima's avatar
Ultima committed
37
    FOUND['MPI'] = False
ultimanet's avatar
ultimanet committed
38
39
40

try:
    import pyfftw
Ultima's avatar
Ultima committed
41
    FOUND['pyfftw'] = True
ultimanet's avatar
ultimanet committed
42
except(ImportError):       
Ultima's avatar
Ultima committed
43
    FOUND['pyfftw'] = False
ultimanet's avatar
ultimanet committed
44
45

try:
46
    import h5py
Ultima's avatar
Ultima committed
47
48
    FOUND['h5py'] = True
    FOUND['h5py_parallel'] = h5py.get_config().mpi
ultimanet's avatar
ultimanet committed
49
except(ImportError):
Ultima's avatar
Ultima committed
50
51
    FOUND['h5py'] = False
    FOUND['h5py_parallel'] = False
ultimanet's avatar
ultimanet committed
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97


class distributed_data_object(object):
    """

        NIFTY class for distributed data

        Parameters
        ----------
        global_data : {tuple, list, numpy.ndarray} *at least 1-dimensional*
            Initial data which will be casted to a numpy.ndarray and then 
            stored according to the distribution strategy. The global_data's
            shape overwrites global_shape.
        global_shape : tuple of ints, *optional*
            If no global_data is supplied, global_shape can be used to
            initialize an empty distributed_data_object
        dtype : type, *optional*
            If an explicit dtype is supplied, the given global_data will be 
            casted to it.            
        distribution_strategy : {'fftw' (default), 'not'}, *optional*
            Specifies the way, how global_data will be distributed to the 
            individual nodes. 
            'fftw' follows the distribution strategy of pyfftw.
            'not' does not distribute the data at all. 
            

        Attributes
        ----------
        data : numpy.ndarray
            The numpy.ndarray in which the individual node's data is stored.
        dtype : type
            Data type of the data object.
        distribution_strategy : string
            Name of the used distribution_strategy
        distributor : distributor
            The distributor object which takes care of all distribution and 
            consolidation of the data. 
        shape : tuple of int
            The global shape of the data
            
        Raises
        ------
        TypeError : 
            If the supplied distribution strategy is not known. 
        
    """
98
    def __init__(self, global_data=None, global_shape=None, dtype=None,
                 local_data=None, local_shape=None,
                 distribution_strategy='fftw', hermitian=False,
                 alias=None, path=None, comm=MPI.COMM_WORLD,
                 copy=True, *args, **kwargs):
        """
        Create the distributor for `distribution_strategy` and let it
        initialize this node's data.

        All keyword arguments (plus any extra `**kwargs`) are handed to
        ``distributor_factory.get_distributor``; the resulting distributor
        determines the final ``dtype`` and global ``shape``. The local data
        and the hermitian flag are then produced by the distributor's
        ``initialize_data``.

        Extra positional (`args`) and keyword (`kwargs`) arguments are stored
        in ``init_args`` / ``init_kwargs`` so that ``copy_empty`` can create
        objects of the same kind.
        """
        # TODO: allow init with empty shape
        self.distributor = distributor_factory.get_distributor(
                                distribution_strategy=distribution_strategy,
                                global_data=global_data,
                                global_shape=global_shape,
                                local_data=local_data,
                                local_shape=local_shape,
                                alias=alias,
                                path=path,
                                dtype=dtype,
                                comm=comm,
                                **kwargs)

        self.distribution_strategy = distribution_strategy
        self.dtype = self.distributor.dtype
        self.shape = self.distributor.global_shape

        self.init_args = args
        self.init_kwargs = kwargs

        # Previously `alias` was always forwarded as `path`, so an explicitly
        # given `path` was ignored. Honour `path` and keep falling back to
        # `alias` when no path was supplied (the earlier behaviour).
        load_path = path if path is not None else alias
        (self.data, self.hermitian) = self.distributor.initialize_data(
                                                    global_data=global_data,
                                                    local_data=local_data,
                                                    alias=alias,
                                                    path=load_path,
                                                    hermitian=hermitian,
                                                    copy=copy)

    def copy(self, dtype=None, distribution_strategy=None, **kwargs):
        """
        Return a copy of this distributed_data_object.

        Parameters
        ----------
        dtype : type, *optional*
            Data type of the copy (defaults to self's dtype).
        distribution_strategy : string, *optional*
            Distribution strategy of the copy (defaults to self's).
        **kwargs
            Passed on to copy_empty.

        Returns
        -------
        distributed_data_object
            A new object holding the same data and the same hermitian flag.
        """
        duplicate = self.copy_empty(dtype=dtype,
                                    distribution_strategy=distribution_strategy,
                                    **kwargs)
        keeps_layout = (distribution_strategy is None or
                        distribution_strategy == self.distribution_strategy)
        if not keeps_layout:
            # Different layout: redistribute everything via inject.
            duplicate.inject((slice(None),), self, (slice(None),))
        else:
            # Same layout: the local chunks line up one-to-one.
            duplicate.set_local_data(self.get_local_data(), copy=True)
        duplicate.hermitian = self.hermitian
        return duplicate

    def copy_empty(self, global_shape=None, dtype=None,
                   distribution_strategy=None, **kwargs):
        """
        Create a new distributed_data_object resembling self, without
        copying any data.

        Parameters
        ----------
        global_shape : tuple of ints, *optional*
            Shape of the new object; defaults to ``self.shape``.
        dtype : type, *optional*
            Data type of the new object; defaults to ``self.dtype``.
        distribution_strategy : string, *optional*
            Distribution strategy; defaults to ``self.distribution_strategy``.
        **kwargs
            Extra constructor keywords. Entries of ``self.init_kwargs`` take
            precedence over keywords of the same name given here.

        Returns
        -------
        distributed_data_object
        """
        # Use identity checks: `== None` on an ndarray shape is element-wise
        # (ambiguous truth value), and a numpy dtype object can compare equal
        # to None because None is coerced to the default dtype.
        if global_shape is None:
            global_shape = self.shape
        if dtype is None:
            dtype = self.dtype
        if distribution_strategy is None:
            distribution_strategy = self.distribution_strategy

        kwargs.update(self.init_kwargs)

        # NOTE(review): `*self.init_args` binds starting at the constructor's
        # first positional parameter (global_data), not at the slots the
        # extras originally occupied -- confirm this is intended.
        temp_d2o = distributed_data_object(
                                global_shape=global_shape,
                                dtype=dtype,
                                distribution_strategy=distribution_strategy,
                                *self.init_args,
                                **kwargs)
        return temp_d2o

    def apply_scalar_function(self, function, inplace=False, dtype=None):
244
245
        remember_hermitianQ = self.hermitian
        
Ultimanet's avatar
Ultimanet committed
246
247
        if inplace == True:        
            temp = self
248
249
250
251
            if dtype != None and self.dtype != dtype:
                about.warnings.cprint(\
            "WARNING: Inplace dtype conversion is not possible!")
                
Ultimanet's avatar
Ultimanet committed
252
        else:
253
            temp = self.copy_empty(dtype=dtype)
Ultimanet's avatar
Ultimanet committed
254
255
256
257
258

        try: 
            temp.data[:] = function(self.data)
        except:
            temp.data[:] = np.vectorize(function)(self.data)
259
        
260
261
262
263
        if function in (np.exp, np.log):
            temp.hermitian = remember_hermitianQ
        else:
            temp.hermitian = False
Ultimanet's avatar
Ultimanet committed
264
265
266
267
268
269
        return temp
    
    def apply_generator(self, generator):
        """
        Fill the local data with ``generator(local_shape)``.

        `generator` receives the distributor's local shape and must return
        this node's data. The hermitian flag is cleared afterwards.
        """
        local_values = generator(self.distributor.local_shape)
        self.set_local_data(local_values)
        self.hermitian = False

    def __str__(self):
        return self.data.__str__()
    
    def __repr__(self):
        return '<distributed_data_object>\n'+self.data.__repr__()
    
276
277
    
    def _compare_helper(self, other, op):
278
        result = self.copy_empty(dtype = np.bool_)
Ultimanet's avatar
Ultimanet committed
279
280
281
        ## Case 1: 'other' is a scalar
        ## -> make point-wise comparison
        if np.isscalar(other):
282
283
            result.set_local_data(
                    getattr(self.get_local_data(copy = False), op)(other))
Ultimanet's avatar
Ultimanet committed
284
285
286
287
288
289
290
            return result        

        ## Case 2: 'other' is a numpy array or a distributed_data_object
        ## -> extract the local data and make point-wise comparison
        elif isinstance(other, np.ndarray) or\
        isinstance(other, distributed_data_object):
            temp_data = self.distributor.extract_local_data(other)
291
292
            result.set_local_data(
                getattr(self.get_local_data(copy=False), op)(temp_data))
Ultimanet's avatar
Ultimanet committed
293
294
295
296
297
298
299
            return result
        
        ## Case 3: 'other' is None
        elif other == None:
            return False
        
        ## Case 4: 'other' is something different
300
        ## -> make a numpy casting and make a recursive call
Ultimanet's avatar
Ultimanet committed
301
302
        else:
            temp_other = np.array(other)
303
            return getattr(self, op)(temp_other)
Ultimanet's avatar
Ultimanet committed
304
        
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323

    def __eq__(self, other):
        """Element-wise ``self == other`` (see _compare_helper)."""
        return self._compare_helper(other, '__eq__')

    def __ne__(self, other):
        """Element-wise ``self != other`` (see _compare_helper)."""
        return self._compare_helper(other, '__ne__')

    def __lt__(self, other):
        """Element-wise ``self < other`` (see _compare_helper)."""
        return self._compare_helper(other, '__lt__')

    def __le__(self, other):
        """Element-wise ``self <= other`` (see _compare_helper)."""
        return self._compare_helper(other, '__le__')

    def __gt__(self, other):
        """Element-wise ``self > other`` (see _compare_helper)."""
        return self._compare_helper(other, '__gt__')

    def __ge__(self, other):
        """Element-wise ``self >= other`` (see _compare_helper)."""
        return self._compare_helper(other, '__ge__')

    def equal(self, other):
Ultimanet's avatar
Ultimanet committed
325
326
327
328
329
330
331
332
333
        if other is None:
            return False
        try:
            assert(self.dtype == other.dtype)
            assert(self.shape == other.shape)
            assert(self.init_args == other.init_args)
            assert(self.init_kwargs == other.init_kwargs)
            assert(self.distribution_strategy == other.distribution_strategy)
            assert(np.all(self.data == other.data))
Ultimanet's avatar
Ultimanet committed
334
        except(AssertionError, AttributeError):
Ultimanet's avatar
Ultimanet committed
335
336
337
338
339
340
341
            return False
        else:
            return True
        

            
    
342
    def __pos__(self):
        """
        Unary plus: return a copy of self.

        The hermitian flag is carried over, consistent with ``copy`` and
        with ``self + 0``.
        """
        temp_d2o = self.copy_empty()
        # Read without copying; set_local_data(copy=True) makes the one copy.
        temp_d2o.set_local_data(data=self.get_local_data(copy=False),
                                copy=True)
        temp_d2o.hermitian = self.hermitian
        return temp_d2o

    def __neg__(self):
        """
        Unary minus: return a new object holding ``-data``.

        Negation preserves hermitian symmetry, so the hermitian flag is
        carried over (as it is for ``self * -1``).
        """
        temp_d2o = self.copy_empty()
        temp_d2o.set_local_data(data=-self.get_local_data(copy=False),
                                copy=True)
        temp_d2o.hermitian = self.hermitian
        return temp_d2o

    def __abs__(self):
        """
        Return a new object holding the element-wise absolute values.

        Complex data yields the real floating-point type of matching
        precision (complex64 -> float32, complex128 -> float64,
        complex256 -> float128); all other dtypes are kept.
        """
        # np.complex / np.float were removed from NumPy, and mapping every
        # other complex type to float lost precision; derive the real type.
        dtype_obj = np.dtype(self.dtype)
        if dtype_obj.kind == 'c':
            # np.finfo of a complex type describes its real component.
            new_dtype = np.finfo(dtype_obj).dtype.type
        else:
            new_dtype = self.dtype
        temp_d2o = self.copy_empty(dtype=new_dtype)
        temp_d2o.set_local_data(data=abs(self.get_local_data(copy=False)),
                                copy=True)
        return temp_d2o

    def __builtin_helper__(self, operator, other, inplace=False):
        """
        Apply a bound numpy operator to the local data and wrap the result.

        Parameters
        ----------
        operator : callable
            Bound method of the local array, e.g.
            ``self.get_local_data().__add__``; it is called with `other`
            (or with the node-local part of `other` if that is array-like).
        other : scalar, numpy.ndarray or distributed_data_object
            Second operand.
        inplace : bool, *optional*
            If True the result is written back into self and self is
            returned; otherwise a new distributed_data_object is created.

        Returns
        -------
        distributed_data_object
        """
        if not (np.isscalar(other) or np.shape(other) == (1,)):
            # Case 1: other is array-like -> use its node-local part. The
            # result is hermitian only if both operands are.
            try:
                hermitian_Q = (other.hermitian and self.hermitian)
            except(AttributeError):
                hermitian_Q = False
            temp_data = self.distributor.extract_local_data(other)
            temp_data = operator(temp_data)

        elif np.isreal(other) or not issubclass(np.dtype(self.dtype).type,
                                               np.complexfloating):
            # Case 2: other is a real scalar or self is not complex
            # -> preserve hermitianity. The complexfloating test covers every
            # complex type (incl. complex64) without referring to np.complex
            # or np.complex256, which may not exist.
            hermitian_Q = self.hermitian
            temp_data = operator(other)

        else:
            # Case 3: complex scalar on complex data.
            hermitian_Q = False
            temp_data = operator(other)

        if inplace:
            temp_d2o = self
        else:
            # Use the common datatype of self and the result; the result is
            # passed as a "scalar type" so self's dtype wins within a kind.
            # NOTE(review): np.find_common_type is deprecated in recent NumPy;
            # np.result_type is not a drop-in replacement for this precedence.
            new_dtype = np.dtype(np.find_common_type((self.dtype,),
                                                     (temp_data.dtype,))).type
            temp_d2o = self.copy_empty(dtype=new_dtype)
        temp_d2o.set_local_data(data=temp_data)
        temp_d2o.hermitian = hermitian_Q
        return temp_d2o

    def __add__(self, other):
        """Element-wise ``self + other``."""
        local_op = self.get_local_data().__add__
        return self.__builtin_helper__(local_op, other)

    def __radd__(self, other):
        """Element-wise ``other + self``."""
        local_op = self.get_local_data().__radd__
        return self.__builtin_helper__(local_op, other)

    def __iadd__(self, other):
        """In-place ``self += other``; returns self."""
        local_op = self.get_local_data().__iadd__
        return self.__builtin_helper__(local_op, other, inplace=True)

    def __sub__(self, other):
        """Element-wise ``self - other``."""
        local_op = self.get_local_data().__sub__
        return self.__builtin_helper__(local_op, other)

    def __rsub__(self, other):
        """Element-wise ``other - self``."""
        local_op = self.get_local_data().__rsub__
        return self.__builtin_helper__(local_op, other)

    def __isub__(self, other):
        """In-place ``self -= other``; returns self."""
        local_op = self.get_local_data().__isub__
        return self.__builtin_helper__(local_op, other, inplace=True)

    def __div__(self, other):
        """Classic (Python 2) division ``self / other``."""
        return self.__builtin_helper__(self.get_local_data().__div__, other)

    def __truediv__(self, other):
        """
        True division ``self / other`` (Python 3, or Python 2 with
        ``from __future__ import division``).

        Delegates to the array's own ``__truediv__``; forwarding to
        ``__div__`` floor-divided integer data and does not exist on
        Python 3 arrays.
        """
        return self.__builtin_helper__(self.get_local_data().__truediv__,
                                       other)

    def __rdiv__(self, other):
        """Classic (Python 2) reflected division ``other / self``."""
        return self.__builtin_helper__(self.get_local_data().__rdiv__, other)

    def __rtruediv__(self, other):
        """Reflected true division ``other / self``."""
        return self.__builtin_helper__(self.get_local_data().__rtruediv__,
                                       other)

    def __idiv__(self, other):
        """Classic (Python 2) in-place division ``self /= other``."""
        return self.__builtin_helper__(self.get_local_data().__idiv__,
                                       other,
                                       inplace=True)

    def __itruediv__(self, other):
        """In-place true division ``self /= other``; returns self."""
        return self.__builtin_helper__(self.get_local_data().__itruediv__,
                                       other,
                                       inplace=True)

    def __floordiv__(self, other):
        """Element-wise floor division ``self // other``."""
        local_op = self.get_local_data().__floordiv__
        return self.__builtin_helper__(local_op, other)

    def __rfloordiv__(self, other):
        """Element-wise reflected floor division ``other // self``."""
        local_op = self.get_local_data().__rfloordiv__
        return self.__builtin_helper__(local_op, other)

    def __ifloordiv__(self, other):
        """In-place ``self //= other``; returns self."""
        local_op = self.get_local_data().__ifloordiv__
        return self.__builtin_helper__(local_op, other, inplace=True)

    def __mul__(self, other):
        """Element-wise ``self * other``."""
        local_op = self.get_local_data().__mul__
        return self.__builtin_helper__(local_op, other)

    def __rmul__(self, other):
        """Element-wise ``other * self``."""
        local_op = self.get_local_data().__rmul__
        return self.__builtin_helper__(local_op, other)

    def __imul__(self, other):
        """In-place ``self *= other``; returns self."""
        local_op = self.get_local_data().__imul__
        return self.__builtin_helper__(local_op, other, inplace=True)

    def __pow__(self, other):
        """Element-wise ``self ** other``."""
        local_op = self.get_local_data().__pow__
        return self.__builtin_helper__(local_op, other)

    def __rpow__(self, other):
        """Element-wise ``other ** self``."""
        local_op = self.get_local_data().__rpow__
        return self.__builtin_helper__(local_op, other)

    def __ipow__(self, other):
        """In-place ``self **= other``; returns self."""
        local_op = self.get_local_data().__ipow__
        return self.__builtin_helper__(local_op, other, inplace=True)

    def __len__(self):
        return self.shape[0]
499
    
500
    def get_dim(self):
501
502
        return np.prod(self.shape)
        
503
    def vdot(self, other):
        """
        Return the global ``np.vdot`` of self and `other`.

        Each node computes ``np.vdot`` (which conjugates its first argument)
        of its local data and the matching local part of `other`; the
        per-node results are gathered and summed.
        """
        other_local = self.distributor.extract_local_data(other)
        partial = np.vdot(self.get_local_data(), other_local)
        return np.sum(self.distributor._allgather(partial))

    def __getitem__(self, key):
        """Return the data selected by `key` (delegates to get_data)."""
        selection = self.get_data(key)
        return selection

    def __setitem__(self, key, data):
        """Store `data` at the positions selected by `key` (via set_data)."""
        self.set_data(data, key)

    def _contraction_helper(self, function, **kwargs):
        """
        Reduce the data globally with `function`.

        `function` (e.g. np.sum) is applied to the local data with `kwargs`;
        the per-node results are gathered and reduced again with
        ``function(..., axis=0)``.

        NOTE(review): np.amin/np.amax raise on an empty local array; check
        whether a node can hold no data under the active distributor.
        """
        partial_result = function(self.data, **kwargs)
        gathered = self.distributor._allgather(partial_result)
        return function(gathered, axis=0)

    def amin(self, **kwargs):
        """Global minimum (np.amin), see _contraction_helper."""
        reducer = np.amin
        return self._contraction_helper(reducer, **kwargs)

    def nanmin(self, **kwargs):
        """Global minimum ignoring NaNs (np.nanmin)."""
        reducer = np.nanmin
        return self._contraction_helper(reducer, **kwargs)

    def amax(self, **kwargs):
        """Global maximum (np.amax), see _contraction_helper."""
        reducer = np.amax
        return self._contraction_helper(reducer, **kwargs)

    def nanmax(self, **kwargs):
        """Global maximum ignoring NaNs (np.nanmax)."""
        reducer = np.nanmax
        return self._contraction_helper(reducer, **kwargs)

    def sum(self, **kwargs):
        """Global sum (np.sum), see _contraction_helper."""
        reducer = np.sum
        return self._contraction_helper(reducer, **kwargs)

    def prod(self, **kwargs):
        """Global product (np.prod), see _contraction_helper."""
        reducer = np.prod
        return self._contraction_helper(reducer, **kwargs)

    def mean(self, power=1):
        """
        Return the global mean of ``data**power`` taken over all nodes.

        Parameters
        ----------
        power : number, *optional*
            Exponent applied element-wise before averaging (default 1).

        Returns
        -------
        scalar
        """
        local_data = self.data**power
        # Accumulate in at least double precision (np.mean does the same for
        # integer input), so integer data is not floor-divided below.
        accumulator = np.result_type(local_data.dtype, np.float64)
        local_sum = np.sum(local_data, dtype=accumulator)
        local_count = local_data.size
        # Gather per-node (sum, count) pairs rather than per-node means: a
        # node without data would contribute NaN (the mean of an empty
        # array), and NaN * 0 still poisons the weighted global mean.
        sum_count_list = np.array(self.distributor._allgather((local_sum,
                                                               local_count)))
        global_sum = np.sum(sum_count_list[:, 0])
        global_count = np.sum(sum_count_list[:, 1])
        return global_sum / global_count

    def var(self):
        """
        Return the global variance of the data.

        Computed as ``mean(|x|**2) - |mean(x)|**2``. For real data this is
        exactly ``mean(x**2) - mean(x)**2``; for complex data the absolute
        values make the result real, matching ``numpy.var``.
        """
        if issubclass(np.dtype(self.dtype).type, np.complexfloating):
            # x**2 would keep the phase; average |x|**2 instead.
            mean_of_the_square = abs(self).mean(power=2)
        else:
            mean_of_the_square = self.mean(power=2)
        square_of_the_mean = abs(self.mean())**2
        return mean_of_the_square - square_of_the_mean

    def std(self):
        """Return the global standard deviation, i.e. ``sqrt(var())``."""
        variance = self.var()
        return np.sqrt(variance)

    def argmin_flat(self):
        """
        Return the global flat index of the minimum.

        Every node finds its local minimum, translates its index to a
        global flat index and all (value, index) pairs are gathered. The
        smallest value wins; ties go to the smallest global index.
        """
        local_flat_index = np.argmin(self.data)
        local_position = np.unravel_index(local_flat_index, self.data.shape)
        local_value = self.data[local_position]
        global_flat_index = self.distributor.globalize_flat_index(
                                                            local_flat_index)
        candidates = self.distributor._allgather((local_value,
                                                  global_flat_index))
        candidate_dtype = [('value', local_value.dtype), ('index', int)]
        candidates = np.array(candidates, dtype=candidate_dtype)
        ranked = np.sort(candidates, order=['value', 'index'])
        winner = ranked[0]
        return winner[1]

    def argmax_flat(self):
        """
        Return the global flat index of the maximum.

        Every node finds its local maximum, translates its index to a global
        flat index and all (value, index) pairs are gathered. The largest
        value wins; ties go to the smallest global index.
        """
        local_argmax = np.argmax(self.data)
        local_argmax_value = self.data[np.unravel_index(local_argmax,
                                                        self.data.shape)]
        globalized_local_argmax = self.distributor.globalize_flat_index(
                                                                local_argmax)
        # Negate the index rather than the value: negating the value wraps
        # around for unsigned integers and is not supported for booleans.
        # Sorting ascending on (value, -index) puts the largest value with
        # the smallest index last.
        local_argmax_list = self.distributor._allgather(
                                (local_argmax_value, -globalized_local_argmax))
        local_argmax_list = np.array(local_argmax_list, dtype=[
                                        ('value', local_argmax_value.dtype),
                                        ('index', int)])
        local_argmax_list = np.sort(local_argmax_list,
                                    order=['value', 'index'])
        return -local_argmax_list[-1][1]

    def argmin(self):
        """Return the global position (index tuple) of the minimum."""
        flat_position = self.argmin_flat()
        return np.unravel_index(flat_position, self.shape)

    def argmax(self):
        """Return the global position (index tuple) of the maximum."""
        flat_position = self.argmax_flat()
        return np.unravel_index(flat_position, self.shape)

    def conjugate(self):
        """Return a new object holding the complex conjugate of the data."""
        result = self.copy_empty()
        conjugated = np.conj(self.get_local_data())
        result.set_local_data(conjugated)
        return result

    def conj(self):
        """Alias for conjugate()."""
        conjugated = self.conjugate()
        return conjugated

    def median(self):
        """
        Return the global median.

        Prints a warning first: the full data set is assembled via
        get_full_data before np.median is applied.
        """
        about.warnings.cprint(
            "WARNING: The current implementation of median is very expensive!")
        full_data = self.get_full_data()
        return np.median(full_data)

    def iscomplex(self):
        """
            Element-wise numpy.iscomplex of the local data.

            Returns
            -------
            out : d2o of dtype numpy.bool_ (created via copy_empty)
        """
        mask_d2o = self.copy_empty(dtype=np.bool_)
        complex_mask = np.iscomplex(self.data)
        mask_d2o.set_local_data(complex_mask)
        return mask_d2o
    
    def isreal(self):
        """
            Element-wise numpy.isreal of the local data.

            Returns
            -------
            out : d2o of dtype numpy.bool_ (created via copy_empty)
        """
        mask_d2o = self.copy_empty(dtype=np.bool_)
        real_mask = np.isreal(self.data)
        mask_d2o.set_local_data(real_mask)
        return mask_d2o
    
663

664
665
666
667
668
669
670
671
    def all(self):
        ## reduce locally first, then combine the per-node results
        per_node_results = self.distributor._allgather(
                                            np.all(self.get_local_data()))
        return np.all(per_node_results)

    def any(self):
        ## reduce locally first, then combine the per-node results
        per_node_results = self.distributor._allgather(
                                            np.any(self.get_local_data()))
        return np.any(per_node_results)
673
        
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
    def unique(self):
        """
            Returns the sorted unique values of the whole d2o as a numpy
            array. Each node reduces its local data first, and only those
            reduced arrays are gathered.
        """
        per_node_uniques = self.distributor._allgather(
                                            np.unique(self.get_local_data()))
        return np.unique(np.concatenate(per_node_uniques))
        
    def bincount(self, weights = None, minlength = None):
        """
            Counts the occurrences of each value in the d2o, like
            numpy.bincount, summed over the data of all nodes.

            Parameters
            ----------
            weights : array-like *optional*
                Global weights; the node-local part is extracted with the
                distributor's extract_local_data.
            minlength : int *optional*
                Minimal length of the result. The result is always at least
                self.amax()+1 long, so that every node produces count arrays
                of the same length.

            Returns
            -------
            counts : numpy.ndarray

            Raises
            ------
            TypeError
                If the d2o is not of integer datatype.
        """
        if not np.issubdtype(self.dtype, np.integer):
            raise TypeError(about._errors.cstring(
                "ERROR: Distributed-data-object must be of integer datatype!"))

        ## Use a Python int: amax()+1 on an unsigned numpy scalar can wrap
        ## around (uint8: 255+1 -> 0), which would give ragged node results.
        required_length = int(self.amax()) + 1
        if minlength is None:
            minlength = required_length
        else:
            minlength = max(required_length, int(minlength))

        if weights is not None:
            local_weights = self.distributor.extract_local_data(weights).\
                                                                    flatten()
        else:
            local_weights = None

        ## np.bincount refuses uint64 input (no safe cast to intp)
        local_values = self.get_local_data().flatten().astype(np.intp)
        local_counts = np.bincount(local_values,
                                   weights = local_weights,
                                   minlength = minlength)
        list_of_counts = self.distributor._allgather(local_counts)
        counts = np.sum(list_of_counts, axis = 0)
        return counts
                              
701
    
702
    def set_local_data(self, data, hermitian=False, copy=True):
        """
            Replaces this node's local data attribute with `data`, without
            any communication between the nodes. The data is converted to a
            C-ordered numpy array of self.dtype.

            Parameters
            ----------
            data : tuple, list, numpy.ndarray
                The new local data; its shape should fit the local shape.
            hermitian : boolean *optional*
                Stored as self.hermitian.
            copy : boolean *optional*
                Forwarded to numpy.array.

            Returns
            -------
            None
        """
        self.hermitian = hermitian
        self.data = np.array(data, copy=copy, dtype=self.dtype, order='C')
ultimanet's avatar
ultimanet committed
720
    
Ultima's avatar
Ultima committed
721
    def set_data(self, data, key, hermitian=False, copy=True, **kwargs):
        """
            Writes `data` into the region of the d2o selected by `key`.

            The work is delegated to the distributor's disperse_data, which
            distributes the update according to the distribution strategy.
            If the nodes pass different keys, their updates are processed
            one after another.

            Parameters
            ----------
            data : tuple, list, numpy.ndarray
                The data which should be distributed.
            key : int, slice, tuple of int or slice
                Selects the region where the data will be stored.
            hermitian : boolean *optional*
                Stored as self.hermitian.
            copy : boolean *optional*
                Forwarded to the distributor.

            Returns
            -------
            None
        """
        self.hermitian = hermitian
        self.distributor.disperse_data(data=self.data, to_key=key,
                                       data_update=data, copy=copy,
                                       **kwargs)
ultimanet's avatar
ultimanet committed
754
    
755
    def set_full_data(self, data, hermitian=False, copy = True, **kwargs):
        """
            Distributes `data`, which must have the global shape of the d2o,
            to the nodes and stores the local part as self.data.

            Equivalent to set_data(data, slice(None)) but faster.

            Parameters
            ----------
            data : tuple, list, numpy.ndarray
                The data which should be distributed.
            hermitian : boolean *optional*
                Stored as self.hermitian.
            copy : boolean *optional*
                Forwarded to the distributor's distribute_data.

            Returns
            -------
            None
        """
        self.hermitian = hermitian
        local_part = self.distributor.distribute_data(data=data, copy=copy,
                                                      **kwargs)
        self.data = local_part
ultimanet's avatar
ultimanet committed
778

Ultimanet's avatar
Ultimanet committed
779
    def get_local_data(self, key=(slice(None),), copy=True):
        """
            Loads data directly from the local data attribute. No
            consolidation between the nodes is done.

            Parameters
            ----------
            key : int, slice, tuple of int or slice
                The key which will be used to access the local data.
            copy : boolean *optional*
                If True (default) the returned array owns its memory, so
                modifying it does not alter the d2o. If False, a view of
                self.data[key] is returned (for the default, full key: the
                self.data array itself).

            Returns
            -------
            self.data[key] : numpy.ndarray (or numpy scalar for scalar keys)
        """
        if copy:
            local_data = self.data[key]
            ## basic slicing yields a view -> detach it from self.data;
            ## numpy scalars are immutable and need no copy
            if isinstance(local_data, np.ndarray):
                local_data = local_data.copy()
            return local_data

        ## copy=False: hand out the array itself if the key selects
        ## everything, otherwise honour the key instead of ignoring it
        key_selects_everything = isinstance(key, tuple) and \
            all(isinstance(item, slice) and item == slice(None)
                for item in key)
        if key_selects_everything:
            return self.data
        return self.data[key]
ultimanet's avatar
ultimanet committed
798
        
799
    def get_data(self, key, **kwargs):
        """
            Loads the region of the d2o selected by `key`.

            The work is delegated to the distributor's collect_data, which
            consolidates the data according to the distribution strategy.
            If the nodes pass different keys, each gets its own data.

            Parameters
            ----------
            key : int, slice, tuple of int or slice
                Selects the region the data will be loaded from.

            Returns
            -------
            global_data[key] : numpy.ndarray
        """
        distributor = self.distributor
        return distributor.collect_data(self.data, key, **kwargs)
ultimanet's avatar
ultimanet committed
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
        
    
    
    def get_full_data(self, target_rank='all'):
        """
            Fully consolidates the distributed data.

            Equivalent to get_data(slice(None)) but faster.

            Parameters
            ----------
            target_rank : 'all' (default), int *optional*
                If only one node should receive the full data, its rank can
                be specified here.

            Returns
            -------
            full_data : numpy.ndarray
                The consolidated data as returned by the distributor's
                consolidate_data (what non-target nodes receive depends on
                the distributor).
        """
        full_data = self.distributor.consolidate_data(self.data,
                                                      target_rank=target_rank)
        return full_data
ultimanet's avatar
ultimanet committed
847

Ultimanet's avatar
Ultimanet committed
848
849
850
851
852
853
854
    def inject(self, to_slices=(slice(None),), data=None,
               from_slices=(slice(None),)):
        """
            Writes data[from_slices] into the region to_slices of this d2o,
            delegating the work to the distributor's inject.

            Parameters
            ----------
            to_slices : tuple of slices *optional*
                Target region in this d2o.
            data : array-like or d2o *optional*
                Source of the update; if None nothing is done.
            from_slices : tuple of slices *optional*
                Region of `data` that is injected.

            Returns
            -------
            self
        """
        ## `data == None` would compare element-wise for arrays, so the
        ## identity test is required here
        if data is None:
            return self

        self.distributor.inject(self.data, to_slices, data, from_slices)
        return self
        
855
856
857
858
859
860
861
862
863
864
865
    def flatten(self, inplace = False):
        """
            Returns a one-dimensional d2o holding the entries of this d2o.

            Parameters
            ----------
            inplace : boolean *optional*
                Forwarded to the distributor's flatten. In both cases the
                newly created d2o is returned; `self` is not replaced.

            Returns
            -------
            flat_d2o : d2o with global shape (product of self.shape,)
        """
        flat_d2o = self.copy_empty(global_shape=(np.prod(self.shape),))
        flat_d2o.set_local_data(
            data=self.distributor.flatten(self.data, inplace=inplace))
        return flat_d2o
        
Ultimanet's avatar
Ultimanet committed
866
        
867

ultimanet's avatar
ultimanet committed
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
      
    def save(self, alias, path=None, overwriteQ=True):
        """
            Saves the distributed_data_object to disk; the work is delegated
            to the distributor's save_data (h5py based).

            Parameters
            ----------
            alias : string
                Name of the dataset within the hdf5 file.
            path : string *optional*
                Path to the hdf5 file. If no path is given, the alias is
                taken as filename in the current path.
            overwriteQ : Boolean *optional*
                Whether an already present dataset may be overwritten.
        """
        save_arguments = (self.data, alias, path, overwriteQ)
        self.distributor.save_data(*save_arguments)

    def load(self, alias, path=None):
        """
            Loads the local data from disk; the work is delegated to the
            distributor's load_data (h5py based) and the result replaces
            self.data.

            Parameters
            ----------
            alias : string
                Name of the dataset within the hdf5 file.
            path : string *optional*
                Path to the hdf5 file. If no path is given, the alias is
                taken as filename in the current path.
        """
        loaded_data = self.distributor.load_data(alias, path)
        self.data = loaded_data
           

    
906
907
908
909
class _distributor_factory(object):
    def __init__(self):
        self.distributor_store = {}
    
Ultima's avatar
Ultima committed
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
#    def parse_kwargs(self, strategy = None, kwargs = {}):
#        return_dict = {}
#        if strategy == 'not':
#            pass
#        ## These strategies use MPI and therefore accept a MPI.comm object
#        if strategy == 'fftw' or strategy == 'equal' or strategy == 'freeform':
#            if kwargs.has_key('comm'):
#                return_dict['comm'] = kwargs['comm']
#
#        return return_dict
    
    def parse_kwargs(self, distribution_strategy, 
                   global_data = None, global_shape = None,
                   local_data = None, local_shape = None,
                   alias = None, path = None,
                   dtype = None, comm = None, **kwargs):

927
        return_dict = {}
Ultima's avatar
Ultima committed
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020

        ## Check for an hdf5 file and open it if given
        if FOUND['h5py'] == True and alias is not None:
            ## set file path            
            file_path = path if (path is not None) else alias 
            ## open hdf5 file
            if FOUND['h5py_parallel'] == True and FOUND['MPI'] == True:
                f = h5py.File(file_path, 'r', driver='mpio', comm=comm)
            else:
                f = h5py.File(file_path, 'r')   
            ## open alias in file
            dset = f[alias] 
        else:
            dset = None


        ## Parse the MPI communicator        
        if distribution_strategy in ['equal', 'fftw', 'freeform']:
            if comm is None:
                raise ValueError(about._errors.cstring(
            "ERROR: The distributor needs a MPI communicator object comm!"))
            else:
                return_dict['comm'] = comm
        
        ## Parse the datatype
        if distribution_strategy in ['not', 'equal', 'fftw'] and \
            (dset is not None):
            dtype = dset.dtype.type
        
        elif distribution_strategy in ['not', 'equal', 'fftw', 'freeform']: 
            if dtype is None:
                if global_data is None and local_data is None:
                    raise ValueError(about._errors.cstring(
            "ERROR: Neither global_data nor local_data nor dtype supplied!"))      
                elif global_data is not None:
                    try:
                        dtype = global_data.dtype.type
                    except(AttributeError):
                        try:
                            dtype = global_data.dtype
                        except(AttributeError):
                            dtype = np.array(global_data).dtype.type
                elif local_data is not None:
                    try:
                        dtype = local_data.dtype.type
                    except(AttributeError):
                        try:
                            dtype = local_data.dtype
                        except(AttributeError):
                            dtype = np.array(local_data).dtype.type
            else:
                dtype = np.dtype(dtype).type                
        return_dict['dtype'] = dtype

        ## Parse the shape
        ## Case 1: global-type slicer
        if distribution_strategy in ['not', 'equal', 'fftw']:       
            if dset is not None:
                global_shape = dset.shape
            elif global_data is not None and np.isscalar(global_data) == False:
                global_shape = global_data.shape
            elif global_shape is not None:
                global_shape = tuple(global_shape)
            else:
                raise ValueError(about._errors.cstring(
                    "ERROR: Neither non-0-dimensional global_data nor " +
                    "global_shape nor hdf5 file supplied!"))      
            if global_shape == ():
                raise ValueError(about._errors.cstring(
                    "ERROR: global_shape == () is not valid shape!"))
            if np.any(np.array(global_shape) == 0):
                raise ValueError(about._errors.cstring(
                    "ERROR: Dimension of size 0 occurred!"))
            
            return_dict['global_shape'] = global_shape

        ## Case 2: local-type slicer
        elif distribution_strategy in ['freeform']:        
            if local_data is not None and np.isscalar(local_data) == False:
                local_shape = local_data.shape
            elif local_shape is not None:
                local_shape = tuple(local_shape)
            else:
                raise ValueError(about._errors.cstring(
                    "ERROR: Neither non-0-dimensional local_data nor " +
                    "local_shape supplied!"))      
            return_dict['local_shape'] = local_shape

            
        ## close the file-handle
        if dset is not None:
            f.close()

1021
        return return_dict
Ultima's avatar
Ultima committed
1022
1023
1024
            
            
    def hash_arguments(self, distribution_strategy, **kwargs):
1025
1026
1027
        kwargs = kwargs.copy()
        if kwargs.has_key('comm'):
            kwargs['comm'] = id(kwargs['comm'])
Ultima's avatar
Ultima committed
1028
1029
1030
1031
1032
1033
1034
        
        if kwargs.has_key('global_shape'):
            kwargs['global_shape'] = kwargs['global_shape']
        if kwargs.has_key('local_shape'):
            kwargs['local_shape'] = kwargs['local_shape']
            
        kwargs['dtype'] = self.dictionize_np(kwargs['dtype'])
1035
        kwargs['distribution_strategy'] = distribution_strategy
1036
        return frozenset(kwargs.items())
ultimanet's avatar
ultimanet committed
1037

1038
1039
1040
1041
1042
1043
1044
1045
    def dictionize_np(self, x):
        dic = x.__dict__.items()
        if x is np.float:
            dic[24] = 0 
            dic[29] = 0
            dic[37] = 0
        return frozenset(dic)            
            
Ultima's avatar
Ultima committed
1046
    def get_distributor(self, distribution_strategy, **kwargs):
1047
        ## check if the distribution strategy is known
1048
        
Ultima's avatar
Ultima committed
1049
1050
        known_distribution_strategies = ['not', 'equal', 'freeform']
        if FOUND['pyfftw'] == True:
1051
            known_distribution_strategies += ['fftw',]
Ultima's avatar
Ultima committed
1052
        if not distribution_strategy in known_distribution_strategies:
1053
1054
1055
1056
            raise TypeError(about._errors.cstring(
                "ERROR: Unknown distribution strategy supplied."))
                
        ## parse the kwargs
Ultima's avatar
Ultima committed
1057
1058
1059
1060
1061
1062
        parsed_kwargs = self.parse_kwargs(
                                distribution_strategy = distribution_strategy, 
                                **kwargs)
                                
        hashed_kwargs = self.hash_arguments(distribution_strategy,
                                            **parsed_kwargs)
1063
1064
        #print hashed_arguments                                               
        ## check if the distributors has already been produced in the past
Ultima's avatar
Ultima committed
1065
1066
        if self.distributor_store.has_key(hashed_kwargs):
            return self.distributor_store[hashed_kwargs]
1067
1068
1069
        else:                                              
            ## produce new distributor
            if distribution_strategy == 'not':
Ultima's avatar
Ultima committed
1070
1071
                produced_distributor = _not_distributor(**parsed_kwargs)
            
1072
1073
            elif distribution_strategy == 'equal':
                produced_distributor = _slicing_distributor(
Ultima's avatar
Ultima committed
1074
1075
1076
1077
                                                slicer = _equal_slicer,
                                                name = distribution_strategy,
                                                **parsed_kwargs)
                                                
1078
1079
            elif distribution_strategy == 'fftw':
                produced_distributor = _slicing_distributor(
Ultima's avatar
Ultima committed
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
                                                slicer = _fftw_slicer,
                                                name = distribution_strategy,
                                                **parsed_kwargs)    
            elif distribution_strategy == 'freeform':
                produced_distributor = _slicing_distributor(
                                                slicer = _freeform_slicer,
                                                name = distribution_strategy,
                                                **parsed_kwargs)             
                                                    
            self.distributor_store[hashed_kwargs] = produced_distributor                                             
            return self.distributor_store[hashed_kwargs]
1091
1092
1093
            
            
## Module-wide factory instance. Its get_distributor caches every produced
## distributor in distributor_store, so calls with identically hashed
## arguments receive the same distributor object.
distributor_factory = _distributor_factory()
ultimanet's avatar
ultimanet committed
1094
        
1095
class _slicing_distributor(object):
Ultima's avatar
Ultima committed
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
    def __init__(self, slicer, name, dtype, comm, **remaining_parsed_kwargs):
        """
            Sets up a distributor which splits the first axis of the global
            array among the nodes of comm.

            Parameters
            ----------
            slicer : callable
                Called as slicer(comm=comm, **remaining_parsed_kwargs); its
                result is read as (local_start, local_end, global_shape).
            name : string
                Name of the distribution strategy, stored as
                self.distribution_strategy.
            dtype : numpy scalar type
                The value supplied on node 0 is broadcast to all nodes.
            comm : MPI communicator

            Raises
            ------
            TypeError
                If no dtype was supplied on node 0, or if the dtype cannot be
                converted into an MPI datatype.
        """
        self.comm = comm
        self.distribution_strategy = name

        ## Node 0 decides on the datatype. Broadcast first and check
        ## afterwards, so that a missing dtype raises on every node instead
        ## of leaving the other nodes waiting in bcast.
        self.dtype = comm.bcast(dtype if comm.rank == 0 else None, root=0)
        if self.dtype is None:
            raise TypeError(about._errors.cstring(
                "ERROR: Failed setting datatype! No datatype supplied."))

        self._my_dtype_converter = _global_dtype_converter
        if not self._my_dtype_converter.known_np_Q(self.dtype):
            raise TypeError(about._errors.cstring(
            "ERROR: The datatype "+str(self.dtype)+" is not known to mpi4py."))
        self.mpi_dtype = self._my_dtype_converter.to_mpi(self.dtype)

        ## the slicer decides which part of the first axis lives on this node
        self.slicer = slicer
        self._local_size = self.slicer(comm = comm, **remaining_parsed_kwargs)
        self.local_start = self._local_size[0]
        self.local_end = self._local_size[1]
        self.global_shape = self._local_size[2]

        self.local_length = self.local_end - self.local_start
        self.local_shape = (self.local_length,) + tuple(self.global_shape[1:])
        ## np.prod / int instead of the np.product / np.int aliases, which
        ## newer numpy versions removed
        self.local_dim = np.prod(self.local_shape)
        self.local_dim_list = np.empty(comm.size, dtype=int)
        comm.Allgather([np.array(self.local_dim, dtype=int), MPI.INT],
                       [self.local_dim_list, MPI.INT])
        self.local_dim_offset = np.sum(self.local_dim_list[0:comm.rank])

        self.local_slice = np.array([self.local_start, self.local_end,
                                     self.local_length, self.local_dim,
                                     self.local_dim_offset],
                                    dtype=int)
        ## collect all local_slices
        ## [start, stop, length=stop-start, dimension, dimension_offset]
        self.all_local_slices = np.empty((comm.size, 5), dtype=int)
        comm.Allgather([np.array((self.local_slice,), dtype=int), MPI.INT],
                       [self.all_local_slices, MPI.INT])
ultimanet's avatar
ultimanet committed
1156
        
Ultima's avatar
Ultima committed
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
    def initialize_data(self, global_data, local_data, alias, path, hermitian,
                        copy, **kwargs):
        """
            Creates the node-local data array for a new d2o.

            Parameters
            ----------
            global_data : scalar or array-like
                Used by the 'equal' and 'fftw' strategies; a scalar fills the
                whole local array.
            local_data : scalar, array-like or None
                Used by the 'freeform' strategy; a scalar fills the whole
                local array, None leaves it uninitialized.
            alias, path : string or None
                If h5py is available and alias is given, the data is loaded
                via self.load_data instead.
            hermitian : boolean
                Passed through unless a scalar fill sets it to True or an
                uninitialized freeform array sets it to False.
            copy : boolean
                Whether supplied array data is copied.

            Returns
            -------
            (local_data, hermitian) : (numpy.ndarray, boolean)
        """
        if alias is not None and FOUND['h5py'] == True:
            local_data = self.load_data(alias = alias, path = path)
            return (local_data, hermitian)

        if self.distribution_strategy in ['equal', 'fftw']:
            if np.isscalar(global_data):
                local_data = np.empty(self.local_shape, dtype = self.dtype)
                local_data.fill(global_data)
                hermitian = True
            else:
                local_data = self.distribute_data(data = global_data,
                                                  copy = copy)
        elif self.distribution_strategy in ['freeform']:
            if np.isscalar(local_data):
                ## fill with the supplied local scalar (not global_data)
                filled_data = np.empty(self.local_shape, dtype = self.dtype)
                filled_data.fill(local_data)
                local_data = filled_data
                hermitian = True
            elif local_data is None:
                local_data = np.empty(self.local_shape, dtype = self.dtype)
                hermitian = False
            else:
                ## np.asarray avoids an unconditional copy, so that the
                ## `copy` flag actually takes effect
                local_data = np.asarray(local_data).astype(
                               self.dtype, copy=copy).reshape(self.local_shape)
        else:
            raise TypeError(about._errors.cstring(
                                        "ERROR: Unknown distribution strategy"))
        return (local_data, hermitian)
ultimanet's avatar
ultimanet committed
1186
        
1187
1188
1189
1190
1191
1192
    def globalize_flat_index(self, index):
        ## shift a node-local flat index by the number of entries that are
        ## stored on the lower-ranked nodes
        offset = self.local_dim_offset
        return int(index) + offset
        
    def globalize_index(self, index):
        """
            Converts a node-local index tuple into a global one by shifting
            its first entry by local_start. The result is clipped to the
            bounds of global_shape; a warning is printed if clipping occurs.

            Parameters
            ----------
            index : sequence of int
                One entry per dimension of the global shape.

            Returns
            -------
            globalized_index : tuple

            Raises
            ------
            TypeError
                If the number of entries does not match the number of
                dimensions.
        """
        ## flatten() returns a fresh array, so it may be modified in place
        index = np.array(index, dtype=int).flatten()
        if index.shape != (len(self.global_shape),):
            raise TypeError(about._errors.cstring(
                "ERROR: Length of index tuple does not match the array's " +
                "shape!"))
        ## only the first axis is distributed among the nodes
        index[0] += self.local_start
        ## ensure that the globalized index list is within the bounds
        shape = np.array(self.global_shape)
        globalized_index = np.clip(index, -shape, shape - 1)
        if np.any(index != globalized_index):
            about.warnings.cprint("WARNING: Indices were clipped!")
        return tuple(globalized_index)
    
    def _allgather(self, thing, comm=None):
        if comm == None:
            comm = self.comm            
        gathered_things = comm.allgather(thing)
        return gathered_things
    
    def distribute_data(self, data=None, comm = None, alias=None,
                        path=None, copy=True, **kwargs):
        '''
        Returns the part of the global data that belongs to this node.

        distribute data checks
        - whether the data is located on all nodes or only on node 0
        - that the shape of 'data' matches the global_shape

        If all nodes got data, it is assumed to be the global array and each
        node takes its slice directly. If only node 0 got data, it is
        scattered. If node 0 got no data, h5py is available and an alias is
        given, the dataset `alias` is read from the hdf5 file `path`
        (defaults to alias). Otherwise an uninitialized local array is
        returned.

        Parameters
        ----------
        data : array-like *optional*
        comm : MPI communicator *optional*
            Defaults to self.comm.
        alias, path : string *optional*
        copy : boolean *optional*
            Passed to astype when all nodes got data.

        Returns
        -------
        local_data : numpy.ndarray

        Raises
        ------
        TypeError
            If the data (or the hdf5 dataset) has the wrong shape/dtype.
        '''
        if comm is None:
            comm = self.comm
        rank = comm.Get_rank()
        size = comm.Get_size()

        ## find out which nodes got data
        local_data_available_Q = np.array((int(data is not None), ))
        data_available_Q = np.empty(size, dtype=int)
        comm.Allgather([local_data_available_Q, MPI.INT],
                       [data_available_Q, MPI.INT])

        ## node 0 has no data -> try to read it from an hdf5 file; without
        ## an alias there is no dataset to look up
        if not data_available_Q[0] and FOUND['h5py'] and alias is not None:
            hdf5_file = None
            try:
                file_path = path if path is not None else alias
                if FOUND['h5py_parallel']:
                    hdf5_file = h5py.File(file_path, 'r', driver='mpio',
                                          comm=comm)
                else:
                    hdf5_file = h5py.File(file_path, 'r')
                dset = hdf5_file[alias]
                if dset.shape == self.global_shape and \
                        dset.dtype.type == self.dtype:
                    return dset[self.local_start:self.local_end]
                else:
                    raise TypeError(about._errors.cstring(
                        "ERROR: Input data has the wrong shape or wrong " +
                        "dtype!"))
            except(IOError, AttributeError):
                pass
            finally:
                if hdf5_file is not None:
                    hdf5_file.close()

        ## if all nodes got data, we assume that it is the right data and
        ## store it individually
        if np.all(data_available_Q):
            return data[self.local_start:self.local_end].astype(self.dtype,
                                                                copy=copy)

        ## node 0 has no data (this also covers "no node has data")
        if not data_available_Q[0]:
            return np.empty(self.local_shape, dtype=self.dtype, order='C')

        ## only node 0 is guaranteed to have data: scatter it
        if rank == 0:
            ## Scatterv sends the raw buffer interpreted as self.mpi_dtype,
            ## hence dtype and memory layout have to match exactly
            data = np.ascontiguousarray(data, dtype=self.dtype)
            if data.shape != tuple(self.global_shape):
                raise TypeError(about._errors.cstring(
                    "ERROR: Input data has the wrong shape!"))
        elif data is None:
            ## placeholder send buffer; only node 0's data is scattered
            data = np.empty(self.global_shape, dtype = self.dtype)

        scattered_data = np.empty(self.local_shape, dtype = self.dtype)
        _dim_list = self.all_local_slices[:, 3]
        _dim_offset_list = self.all_local_slices[:, 4]
        comm.Scatterv([data, _dim_list, _dim_offset_list, self.mpi_dtype],
                      [scattered_data, self.mpi_dtype], root=0)
        return scattered_data
    
1276
1277
    

Ultima's avatar
Ultima committed
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
    def disperse_data(self, data, to_key, data_update, from_key=None,
                      comm=None, copy=True, **kwargs):
        """
            Thin wrapper around disperse_data_from_slices: to_key and
            from_key are forwarded as to_slices and from_slices.
        """
        return self.disperse_data_from_slices(data, to_key, data_update,
                                              from_key, comm, copy,
                                              **kwargs)
                                              
    def disperse_data_from_slices(self, data, to_slices, data_update,
                                  from_slices=None, comm=None, copy=True,
                                  **kwargs):
        """
            Disperses `data_update` into `data` at the positions described
            by `to_slices`.

            Every node announces its own `to_slices` via `comm.allgather`.
            If all nodes requested identical slices, the primitive
            `_disperse_data_from_slices_primitive` is called once with
            ``source_rank='all'``. Otherwise it is called once per rank, in
            rank order, with that rank's slices and ``source_rank`` set to
            the rank index. Every node makes each of these calls.

            Parameters
            ----------
            data : array
                Node-local data array; forwarded to the primitive.
            to_slices : key
                Target positions; normalized by ``self._sliceify`` before use.
            data_update : array-like
                Values to write; passed through ``self._enfold`` together
                with the flag returned by ``self._sliceify``.
            from_slices : tuple of slices, optional
                Forwarded unchanged to the primitive (default: None).
            comm : MPI communicator, optional
                Communicator to use; defaults to ``self.comm``.
            copy : bool, optional
                Forwarded unchanged to the primitive (default: True).
            **kwargs
                Accepted for interface compatibility; ignored here.

            Returns
            -------
            None
        """
        (to_slices, sliceified) = self._sliceify(to_slices)
        data_update = self._enfold(data_update, sliceified)

        if comm is None:
            comm = self.comm

        # Every node executes the same sequence of primitive calls below
        # (the primitive is presumably collective), so every node first
        # learns which slices each rank requested.
        to_slices_list = comm.allgather(to_slices)

        if all(x == to_slices_list[0] for x in to_slices_list):
            # All nodes requested the same slices: a single call with
            # source_rank='all' covers everyone.
            self._disperse_data_from_slices_primitive(
                data=data,
                to_slices=to_slices,
                data_update=data_update,
                from_slices=from_slices,
                source_rank='all',
                comm=comm,
                copy=copy)
        else:
            # The nodes requested different slices: serve each rank's
            # request in turn, with that rank acting as the source.
            for rank, rank_to_slices in enumerate(to_slices_list):
                self._disperse_data_from_slices_primitive(
                    data=data,
                    to_slices=rank_to_slices,
                    data_update=data_update,
                    from_slices=from_slices,
                    source_rank=rank,
                    comm=comm,
                    copy=copy)
                 
        
#    def _disperse_data_primitive(self, data, to_slices, data_update, 
#                        from_slices, source_rank='all', comm=None, copy=True):
#        ## compute the part of the to_slice which is relevant for the 
#        ## individual node      
#        localized_to_start, localized_to_stop = self._backshift_and_decycle(
#            to_slices[0], self.local_start, self.local_end,\
#                self.global_shape[0])
#        local_to_slice = (slice(localized_to_start, localized_to_stop,\
#                        to_slices[0].step),) + to_slices[1:]
#                        
#        ## compute the parameter sets and list for the data splitting
#        local_slice_shape = data[local_slice].shape        
#        local_affected_data_length = local_slice_shape[0]
#        local_affected_data_length_list=np.empty(comm.size, dtype=np.int)        
#        comm.Allgather(\
#            [np.array(local_affected_data_length, dtype=np.int), MPI.INT],\
#            [local_affected_data_length_list, MPI.INT])        
#        local_affected_data_length_offset_list = np.append([0],\
#                            np.cumsum(local_affected_data_length_list)[:-1])
#
#    
    
Ultima's avatar
Ultima committed
1347
1348
1349
    def _disperse_data_from_slices_primitive(self, data, to_slices, 
                        data_update, from_slices, source_rank='all', comm=None, 
                        copy=True):
1350
1351
1352
1353
1354
1355
1356
1357
1358
        if comm == None:
            comm = self.comm         
    
#        if to_slices[0].step is not None and to_slices[0].step < -1:
#            raise ValueError(about._errors.cstring(
#                "ERROR: Negative stepsizes other than -1 are not supported!"))

        ## parse the to_slices object
        localized_to_start, localized_to_stop=self._backshift_and_decycle(
Ultimanet's avatar
Ultimanet committed
1359
            to_slices[0], self.local_start, self.local_end,\
1360
                self.global_shape[0])
1361
1362
1363
1364
        local_to_slice = (slice(localized_to_start, localized_to_stop,\
                        to_slices[0].step),) + to_slices[1:]   
        local_to_slice_shape = data[local_to_slice].shape        

ultimanet's avatar
ultimanet committed
1365
        if source_rank == 'all':
1366
        
Ultimanet's avatar
Ultimanet committed
1367
            
1368
            ## parse the from_slices object
Ultimanet's avatar
Ultimanet committed
1369
            if from_slices == None:
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
                from_slices = (slice(None, None, None),) 
            (from_slices_start, from_slices_stop)=self._backshift_and_decycle(
                                        slice_object = from_slices[0],
                                        shifted_start = 0,
                                        shifted_stop = data_update.shape[0],
                                        global_length = data_update.shape[0])
            if from_slices_start == None:
                raise ValueError(about._errors.cstring(\
                        "ERROR: _backshift_and_decycle should never return "+\
                        "None for local_start!"))
1380
                        
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404


            ## parse the step sizes
            from_step = from_slices[0].step
            if from_step == None:
                from_step = 1
            elif from_step == 0:            
                raise ValueError(about._errors.cstring(\
                    "ERROR: from_step size == 0!"))

            to_step = to_slices[0].step
            if to_step == None:
                to_step = 1
            elif to_step == 0:            
                raise ValueError(about._errors.cstring(\
                    "ERROR: to_step size == 0!"))


            
            ## Compute the offset of the data the individual node will take.
            ## The offset is free of stepsizes. It is the offset in terms of 
            ## the purely transported data. If to_step < 0, the offset will
            ## be calculated in reverse order
            order = np.sign(to_step)
1405
            
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415