nifty_mpi_data.py 118 KB
Newer Older
ultimanet's avatar
ultimanet committed
1
# -*- coding: utf-8 -*-
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
## NIFTY (Numerical Information Field Theory) has been developed at the
## Max-Planck-Institute for Astrophysics.
##
## Copyright (C) 2015 Max-Planck-Society
##
## Author: Theo Steininger
## Project homepage: <http://www.mpa-garching.mpg.de/ift/nifty/>
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
## See the GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <http://www.gnu.org/licenses/>.


ultimanet's avatar
ultimanet committed
24
25


26

Ultima's avatar
Ultima committed
27
28
##initialize the 'FOUND-packages'-dictionary 
FOUND = {}
ultimanet's avatar
ultimanet committed
29
import numpy as np
Ultimanet's avatar
Ultimanet committed
30
from nifty_about import about
31
from weakref import WeakValueDictionary as weakdict
ultimanet's avatar
ultimanet committed
32
33

try:
34
    from mpi4py import MPI
Ultima's avatar
Ultima committed
35
    FOUND['MPI'] = True
ultimanet's avatar
ultimanet committed
36
except(ImportError): 
37
    import mpi_dummy as MPI
Ultima's avatar
Ultima committed
38
    FOUND['MPI'] = False
ultimanet's avatar
ultimanet committed
39
40
41

try:
    import pyfftw
Ultima's avatar
Ultima committed
42
    FOUND['pyfftw'] = True
ultimanet's avatar
ultimanet committed
43
except(ImportError):       
Ultima's avatar
Ultima committed
44
    FOUND['pyfftw'] = False
ultimanet's avatar
ultimanet committed
45
46

try:
47
    import h5py
Ultima's avatar
Ultima committed
48
49
    FOUND['h5py'] = True
    FOUND['h5py_parallel'] = h5py.get_config().mpi
ultimanet's avatar
ultimanet committed
50
except(ImportError):
Ultima's avatar
Ultima committed
51
52
    FOUND['h5py'] = False
    FOUND['h5py_parallel'] = False
ultimanet's avatar
ultimanet committed
53
54


55
56
57

COMM = MPI.COMM_WORLD

ultimanet's avatar
ultimanet committed
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
class distributed_data_object(object):
    """

        NIFTY class for distributed data

        Parameters
        ----------
        global_data : {tuple, list, numpy.ndarray} *at least 1-dimensional*
            Initial data which will be casted to a numpy.ndarray and then 
            stored according to the distribution strategy. The global_data's
            shape overwrites global_shape.
        global_shape : tuple of ints, *optional*
            If no global_data is supplied, global_shape can be used to
            initialize an empty distributed_data_object
        dtype : type, *optional*
            If an explicit dtype is supplied, the given global_data will be 
            casted to it.            
        distribution_strategy : {'fftw' (default), 'not'}, *optional*
            Specifies the way, how global_data will be distributed to the 
            individual nodes. 
            'fftw' follows the distribution strategy of pyfftw.
            'not' does not distribute the data at all. 
            

        Attributes
        ----------
        data : numpy.ndarray
            The numpy.ndarray in which the individual node's data is stored.
        dtype : type
            Data type of the data object.
        distribution_strategy : string
            Name of the used distribution_strategy
        distributor : distributor
            The distributor object which takes care of all distribution and 
            consolidation of the data. 
        shape : tuple of int
            The global shape of the data
            
        Raises
        ------
        TypeError : 
            If the supplied distribution strategy is not known. 
        
    """
102
    def __init__(self, global_data=None, global_shape=None, dtype=None,
                 local_data=None, local_shape=None,
                 distribution_strategy='fftw', hermitian=False,
                 alias=None, path=None, comm=MPI.COMM_WORLD,
                 copy=True, *args, **kwargs):
        """
            Sets up the distributor and initializes the node's local data.

            Parameters
            ----------
            global_data : {tuple, list, numpy.ndarray}, *optional*
                Global input data; tuples and lists are converted to a
                numpy.ndarray before being handed to the distributor.
            global_shape : tuple of ints, *optional*
                Forwarded to the distributor factory.
            dtype : type, *optional*
                Forwarded to the distributor factory.
            local_data : {tuple, list, numpy.ndarray}, *optional*
                Node-local input data; tuples and lists are converted to a
                numpy.ndarray.
            local_shape : tuple of ints, *optional*
                Forwarded to the distributor factory.
            distribution_strategy : string, *optional*
                Name of the distribution strategy (default: 'fftw').
            hermitian : bool, *optional*
                Forwarded to the distributor's `initialize_data`.
            alias, path : *optional*
                Forwarded to the distributor factory and to
                `initialize_data`.
            comm : communicator, *optional*
                Forwarded to the distributor factory
                (default: MPI.COMM_WORLD).
            copy : bool, *optional*
                Forwarded to the distributor's `initialize_data`.
            *args, **kwargs
                Stored as `init_args` / `init_kwargs` so that `copy_empty`
                can re-use them; `kwargs` are additionally forwarded to the
                distributor factory.
        """
        ## TODO: allow init with empty shape

        ## np.asarray converts sequences and passes ndarrays through
        ## unchanged (no needless copy)
        if isinstance(global_data, (tuple, list)):
            global_data = np.asarray(global_data)
        if isinstance(local_data, (tuple, list)):
            local_data = np.asarray(local_data)

        self.distributor = distributor_factory.get_distributor(
                                distribution_strategy=distribution_strategy,
                                comm=comm,
                                global_data=global_data,
                                global_shape=global_shape,
                                local_data=local_data,
                                local_shape=local_shape,
                                alias=alias,
                                path=path,
                                dtype=dtype,
                                **kwargs)

        ## the distributor has the final say on dtype and global shape
        self.distribution_strategy = distribution_strategy
        self.dtype = self.distributor.dtype
        self.shape = self.distributor.global_shape

        ## remember the surplus arguments for copy_empty
        self.init_args = args
        self.init_kwargs = kwargs

        ## hand the user-supplied path through; it was previously replaced
        ## by the alias, so an explicit path was silently ignored here
        (self.data, self.hermitian) = self.distributor.initialize_data(
                                                     global_data=global_data,
                                                     local_data=local_data,
                                                     alias=alias,
                                                     path=path,
                                                     hermitian=hermitian,
                                                     copy=copy)
        self.index = d2o_librarian.register(self)
Ultima's avatar
Ultima committed
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
#        ## If a hdf5 path was given, load the data
#        if FOUND['h5py'] == True and alias is not None:
#            self.load(alias = alias, path = path)
#            ## close the file handle
#            f.close()
#            
#        ## If the input data was a scalar, set the whole array to this value
#        elif global_data is not None and np.isscalar(global_data):
#            temp = np.empty(self.distributor.local_shape, dtype = self.dtype)
#            temp.fill(global_data)
#            self.set_local_data(temp)
#            self.hermitian = True
#        else:
#            self.set_full_data(data=global_data, hermitian=hermitian, 
#                               copy = copy, **kwargs)
#            
Ultimanet's avatar
Ultimanet committed
223
224
225
226
227
228
229
230
    def copy(self, dtype=None, distribution_strategy=None, **kwargs):
        """
            Returns a copy of this object, optionally with another dtype
            and/or distribution strategy.

            The empty target is created by `copy_empty`. If the strategy is
            unchanged the local data is copied directly, otherwise the data
            is injected into the target. The hermitian flag is carried over.
        """
        duplicate = self.copy_empty(
                            dtype=dtype,
                            distribution_strategy=distribution_strategy,
                            **kwargs)
        keeps_strategy = distribution_strategy in (
                                            None, self.distribution_strategy)
        if keeps_strategy:
            duplicate.set_local_data(self.get_local_data(), copy=True)
        else:
            ## different layout: let inject redistribute everything
            everything = (slice(None),)
            duplicate.inject(everything, self, everything)
        duplicate.hermitian = self.hermitian
        return duplicate
    
236
    def copy_empty(self, global_shape=None, local_shape=None, dtype=None,
                   distribution_strategy=None, **kwargs):
        """
            Creates a new distributed_data_object that mirrors this one
            without passing any data.

            Parameters
            ----------
            global_shape : tuple of ints, *optional*
                Defaults to this object's shape.
            local_shape : tuple of ints, *optional*
                Passed on unchanged.
            dtype : type, *optional*
                Defaults to this object's dtype.
            distribution_strategy : string, *optional*
                Defaults to this object's distribution strategy.
            **kwargs
                Extra constructor keywords; entries of `self.init_kwargs`
                take precedence over them.

            Returns
            -------
            out : distributed_data_object
        """
        ## identity checks on purpose: `== None` is evaluated element-wise
        ## for array-like shapes, and older numpy versions treat
        ## `np.dtype('float64') == None` as True, which silently discarded
        ## an explicitly requested float64 dtype
        if global_shape is None:
            global_shape = self.shape
        if dtype is None:
            dtype = self.dtype
        if distribution_strategy is None:
            distribution_strategy = self.distribution_strategy

        kwargs.update(self.init_kwargs)

        ## NOTE(review): init_args are passed positionally and would bind to
        ## the leading constructor parameters if non-empty -- confirm intent
        temp_d2o = distributed_data_object(
                               global_shape=global_shape,
                               local_shape=local_shape,
                               dtype=dtype,
                               distribution_strategy=distribution_strategy,
                               *self.init_args,
                               **kwargs)
        return temp_d2o
    
255
    def apply_scalar_function(self, function, inplace=False, dtype=None):
256
257
        remember_hermitianQ = self.hermitian
        
Ultimanet's avatar
Ultimanet committed
258
259
        if inplace == True:        
            temp = self
260
            if dtype != None and self.dtype != np.dtype(dtype):
261
262
263
                about.warnings.cprint(\
            "WARNING: Inplace dtype conversion is not possible!")
                
Ultimanet's avatar
Ultimanet committed
264
        else:
265
            temp = self.copy_empty(dtype=dtype)
Ultimanet's avatar
Ultimanet committed
266
267
268
269
270

        try: 
            temp.data[:] = function(self.data)
        except:
            temp.data[:] = np.vectorize(function)(self.data)
271
        
272
273
274
275
        if function in (np.exp, np.log):
            temp.hermitian = remember_hermitianQ
        else:
            temp.hermitian = False
Ultimanet's avatar
Ultimanet committed
276
277
278
279
280
281
        return temp
    
    def apply_generator(self, generator):
        """
            Fills the local data with `generator(local_shape)` and clears
            the hermitian flag.
        """
        local_shape = self.distributor.local_shape
        fresh_data = generator(local_shape)
        self.set_local_data(fresh_data)
        self.hermitian = False
            
ultimanet's avatar
ultimanet committed
282
283
284
285
286
287
    def __str__(self):
        return self.data.__str__()
    
    def __repr__(self):
        return '<distributed_data_object>\n'+self.data.__repr__()
    
288
289
    
    def _compare_helper(self, other, op):
290
        result = self.copy_empty(dtype = np.bool_)
Ultimanet's avatar
Ultimanet committed
291
292
293
        ## Case 1: 'other' is a scalar
        ## -> make point-wise comparison
        if np.isscalar(other):
294
295
            result.set_local_data(
                    getattr(self.get_local_data(copy = False), op)(other))
Ultimanet's avatar
Ultimanet committed
296
297
298
299
300
301
302
            return result        

        ## Case 2: 'other' is a numpy array or a distributed_data_object
        ## -> extract the local data and make point-wise comparison
        elif isinstance(other, np.ndarray) or\
        isinstance(other, distributed_data_object):
            temp_data = self.distributor.extract_local_data(other)
303
304
            result.set_local_data(
                getattr(self.get_local_data(copy=False), op)(temp_data))
Ultimanet's avatar
Ultimanet committed
305
306
307
308
309
310
311
            return result
        
        ## Case 3: 'other' is None
        elif other == None:
            return False
        
        ## Case 4: 'other' is something different
312
        ## -> make a numpy casting and make a recursive call
Ultimanet's avatar
Ultimanet committed
313
314
        else:
            temp_other = np.array(other)
315
            return getattr(self, op)(temp_other)
Ultimanet's avatar
Ultimanet committed
316
        
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335

    def __ne__(self, other):
        """Point-wise ``self != other``, evaluated by `_compare_helper`."""
        return self._compare_helper(other, op='__ne__')
        
    def __lt__(self, other):
        """Point-wise ``self < other``, evaluated by `_compare_helper`."""
        return self._compare_helper(other, op='__lt__')
            
    def __le__(self, other):
        """Point-wise ``self <= other``, evaluated by `_compare_helper`."""
        return self._compare_helper(other, op='__le__')

    def __eq__(self, other):
        """Point-wise ``self == other``, evaluated by `_compare_helper`."""
        return self._compare_helper(other, op='__eq__')

    def __ge__(self, other):
        """Point-wise ``self >= other``, evaluated by `_compare_helper`."""
        return self._compare_helper(other, op='__ge__')

    def __gt__(self, other):
        """Point-wise ``self > other``, evaluated by `_compare_helper`."""
        return self._compare_helper(other, op='__gt__')

Ultimanet's avatar
Ultimanet committed
336
    def equal(self, other):
Ultimanet's avatar
Ultimanet committed
337
338
339
340
341
342
343
344
345
        if other is None:
            return False
        try:
            assert(self.dtype == other.dtype)
            assert(self.shape == other.shape)
            assert(self.init_args == other.init_args)
            assert(self.init_kwargs == other.init_kwargs)
            assert(self.distribution_strategy == other.distribution_strategy)
            assert(np.all(self.data == other.data))
Ultimanet's avatar
Ultimanet committed
346
        except(AssertionError, AttributeError):
Ultimanet's avatar
Ultimanet committed
347
348
349
350
351
352
353
            return False
        else:
            return True
        

            
    
354
    def __pos__(self):
        """Returns ``+self``: an empty copy filled with the local data."""
        result = self.copy_empty()
        local_data = self.get_local_data()
        result.set_local_data(data=local_data, copy=True)
        return result
        
ultimanet's avatar
ultimanet committed
359
    def __neg__(self):
        """Returns ``-self``: an empty copy filled with the negated data."""
        result = self.copy_empty()
        negated = self.get_local_data().__neg__()
        result.set_local_data(data=negated, copy=True)
        return result
    
365
    def __abs__(self):
        """
            Returns ``abs(self)`` as a new distributed_data_object.

            complex64 maps to float32, complex128 to float64 and every other
            complex floating dtype to 'float'; non-complex dtypes are kept.
        """
        ## translate complex dtypes into their real counterparts
        for complex_name, real_name in (('complex64', 'float32'),
                                        ('complex128', 'float64')):
            if self.dtype == np.dtype(complex_name):
                new_dtype = np.dtype(real_name)
                break
        else:
            if issubclass(self.dtype.type, np.complexfloating):
                new_dtype = np.dtype('float')
            else:
                new_dtype = self.dtype

        result = self.copy_empty(dtype=new_dtype)
        magnitudes = self.get_local_data().__abs__()
        result.set_local_data(data=magnitudes, copy=True)
        return result
ultimanet's avatar
ultimanet committed
379
            
380
    def __builtin_helper__(self, operator, other, inplace=False):
        """
            Common back-end of the arithmetic operators.

            Parameters
            ----------
            operator : callable
                Bound operator method of the local data array, e.g.
                ``self.get_local_data().__add__``.
            other : {scalar, array-like, distributed_data_object}
                Second operand. Non-scalars are reduced to the part that
                matches this node's local data before `operator` is called.
            inplace : bool, *optional*
                If True the result is written into `self`, otherwise into an
                empty copy (default: False).

            Returns
            -------
            out : distributed_data_object
                `self` or a new object; NotImplemented if `operator` does
                not support the operand.
        """
        ## Case 1: other is not a scalar
        if not (np.isscalar(other) or np.shape(other) == (1,)):
            try:
                hermitian_Q = (other.hermitian and self.hermitian)
            except(AttributeError):
                hermitian_Q = False
            ## extract the local data from the 'other' object
            temp_data = self.distributor.extract_local_data(other)
            temp_data = operator(temp_data)

        ## Case 2: other is a real scalar, or self holds no complex data
        ## -> preserve hermitianity
        ## The complex check covers every complex floating dtype; naming
        ## 'complex256' explicitly raises on platforms that lack it, and
        ## complex64 was missing from the former list.
        elif np.isreal(other) or not issubclass(self.dtype.type,
                                                np.complexfloating):
            hermitian_Q = self.hermitian
            temp_data = operator(other)

        ## Case 3: other is complex
        else:
            hermitian_Q = False
            temp_data = operator(other)

        ## let Python try the reflected operation of the other operand
        if temp_data is NotImplemented:
            return NotImplemented

        ## write the new data into a new distributed_data_object
        if inplace:
            temp_d2o = self
        else:
            ## use common datatype for self and other
            new_dtype = np.dtype(np.find_common_type((self.dtype,),
                                                     (temp_data.dtype,)))
            temp_d2o = self.copy_empty(dtype=new_dtype)
        temp_d2o.set_local_data(data=temp_data)
        temp_d2o.hermitian = hermitian_Q
        return temp_d2o
416
    """
Ultimanet's avatar
Ultimanet committed
417
    def __inplace_builtin_helper__(self, operator, other):
418
        ## Case 1: other is not a scalar
Ultimanet's avatar
Ultimanet committed
419
420
421
        if not (np.isscalar(other) or np.shape(other) == (1,)):        
            temp_data = self.distributor.extract_local_data(other)
            temp_data = operator(temp_data)
422
423
424
        ## Case 2: other is a real scalar -> preserve hermitianity
        elif np.isreal(other):
            hermitian_Q = self.hermitian
Ultimanet's avatar
Ultimanet committed
425
            temp_data = operator(other)
426
427
428
        ## Case 3: other is complex
        else:
            temp_data = operator(other)        
Ultimanet's avatar
Ultimanet committed
429
        self.set_local_data(data=temp_data)
430
        self.hermitian = hermitian_Q
Ultimanet's avatar
Ultimanet committed
431
        return self
432
    """ 
Ultimanet's avatar
Ultimanet committed
433
    
ultimanet's avatar
ultimanet committed
434
435
436
437
438
    def __add__(self, other):
        """Returns ``self + other`` via `__builtin_helper__`."""
        local_op = self.get_local_data().__add__
        return self.__builtin_helper__(local_op, other)

    def __radd__(self, other):
        """Returns ``other + self`` via `__builtin_helper__`."""
        local_op = self.get_local_data().__radd__
        return self.__builtin_helper__(local_op, other)
Ultimanet's avatar
Ultimanet committed
439
440

    def __iadd__(self, other):
        """Implements ``self += other``; the result is written into self."""
        local_op = self.get_local_data().__iadd__
        return self.__builtin_helper__(local_op, other, inplace=True)
Ultimanet's avatar
Ultimanet committed
444

ultimanet's avatar
ultimanet committed
445
446
447
448
449
450
451
    def __sub__(self, other):
        """Returns ``self - other`` via `__builtin_helper__`."""
        local_op = self.get_local_data().__sub__
        return self.__builtin_helper__(local_op, other)
    
    def __rsub__(self, other):
        """Returns ``other - self`` via `__builtin_helper__`."""
        local_op = self.get_local_data().__rsub__
        return self.__builtin_helper__(local_op, other)
    
    def __isub__(self, other):
        """Implements ``self -= other``; the result is written into self."""
        local_op = self.get_local_data().__isub__
        return self.__builtin_helper__(local_op, other, inplace=True)
ultimanet's avatar
ultimanet committed
455
456
457
458
        
    def __div__(self, other):
        """Returns ``self / other`` (classic division) via the local
        array's ``__div__``."""
        local_op = self.get_local_data().__div__
        return self.__builtin_helper__(local_op, other)
    
459
460
461
    def __truediv__(self, other):
        """
            Returns ``self / other`` with true-division semantics.

            Uses the local array's ``__truediv__`` instead of delegating to
            ``__div__``: classic division floor-divides integer data, and
            ``__div__`` is only available under Python 2.
        """
        return self.__builtin_helper__(self.get_local_data().__truediv__,
                                       other)
        
ultimanet's avatar
ultimanet committed
462
463
    def __rdiv__(self, other):
        """Returns ``other / self`` (classic division) via the local
        array's ``__rdiv__``."""
        local_op = self.get_local_data().__rdiv__
        return self.__builtin_helper__(local_op, other)
464
465
466
    
    def __rtruediv__(self, other):
        """
            Returns ``other / self`` with true-division semantics.

            Uses the local array's ``__rtruediv__`` instead of delegating to
            ``__rdiv__``: classic division floor-divides integer data, and
            ``__rdiv__`` is only available under Python 2.
        """
        return self.__builtin_helper__(self.get_local_data().__rtruediv__,
                                       other)
ultimanet's avatar
ultimanet committed
467

Ultimanet's avatar
Ultimanet committed
468
    def __idiv__(self, other):
        """Implements classic ``self /= other``; the result is written
        into self."""
        local_op = self.get_local_data().__idiv__
        return self.__builtin_helper__(local_op, other, inplace=True)
472
    def __itruediv__(self, other):
        """
            Implements ``self /= other`` with true-division semantics; the
            result is written into self.

            Uses the local array's ``__itruediv__`` instead of delegating to
            ``__idiv__``, which performs classic division and is only
            available under Python 2.
        """
        ## NOTE(review): numpy may refuse in-place true division for
        ## integer data (float result cannot be cast back) -- confirm that
        ## no caller relies on the former floor-division behaviour
        return self.__builtin_helper__(self.get_local_data().__itruediv__,
                                       other,
                                       inplace=True)
                                               
ultimanet's avatar
ultimanet committed
475
    def __floordiv__(self, other):
        """Returns ``self // other`` via `__builtin_helper__`."""
        local_op = self.get_local_data().__floordiv__
        return self.__builtin_helper__(local_op, other)
ultimanet's avatar
ultimanet committed
478
    def __rfloordiv__(self, other):
        """Returns ``other // self`` via `__builtin_helper__`."""
        local_op = self.get_local_data().__rfloordiv__
        return self.__builtin_helper__(local_op, other)
    def __ifloordiv__(self, other):
        """Implements ``self //= other``; the result is written into
        self."""
        local_op = self.get_local_data().__ifloordiv__
        return self.__builtin_helper__(local_op, other, inplace=True)
ultimanet's avatar
ultimanet committed
485
486
487
488
489
490
491
492
    
    def __mul__(self, other):
        """Returns ``self * other`` via `__builtin_helper__`."""
        local_op = self.get_local_data().__mul__
        return self.__builtin_helper__(local_op, other)
    
    def __rmul__(self, other):
        """Returns ``other * self`` via `__builtin_helper__`."""
        local_op = self.get_local_data().__rmul__
        return self.__builtin_helper__(local_op, other)

    def __imul__(self, other):
        """Implements ``self *= other``; the result is written into self."""
        local_op = self.get_local_data().__imul__
        return self.__builtin_helper__(local_op, other, inplace=True)
Ultimanet's avatar
Ultimanet committed
496

ultimanet's avatar
ultimanet committed
497
498
499
500
501
502
503
    def __pow__(self, other):
        """Returns ``self ** other`` via `__builtin_helper__`."""
        local_op = self.get_local_data().__pow__
        return self.__builtin_helper__(local_op, other)
 
    def __rpow__(self, other):
        """Returns ``other ** self`` via `__builtin_helper__`."""
        local_op = self.get_local_data().__rpow__
        return self.__builtin_helper__(local_op, other)

    def __ipow__(self, other):
        """Implements ``self **= other``; the result is written into
        self."""
        local_op = self.get_local_data().__ipow__
        return self.__builtin_helper__(local_op, other, inplace=True)
Ultimanet's avatar
Ultimanet committed
507
   
508
509
    def __len__(self):
        return self.shape[0]
510
    
511
    def get_dim(self):
512
513
        return np.prod(self.shape)
        
514
    def vdot(self, other):
        """
            Computes ``np.vdot`` of this object with `other`.

            The node-local partial products are collected with the
            distributor's `_allgather` and summed up.
        """
        local_other = self.distributor.extract_local_data(other)
        partial_product = np.vdot(self.get_local_data(), local_other)
        all_partials = self.distributor._allgather(partial_product)
        return np.sum(all_partials)
            
Ultimanet's avatar
Ultimanet committed
521

522
    
ultimanet's avatar
ultimanet committed
523
    def __getitem__(self, key):
        """Implements ``self[key]`` by delegating to `get_data`."""
        selection = self.get_data(key)
        return selection
        
#        ## Case 1: key is a boolean array.
#        ## -> take the local data portion from key, use this for data 
#        ## extraction, and then merge the result in a flat numpy array
#        if isinstance(key, np.ndarray):
#            found = 'ndarray'
#            found_boolean = (key.dtype.type == np.bool_)
#        elif isinstance(key, distributed_data_object):
#            found = 'd2o'
#            found_boolean = (key.dtype == np.bool_)
#        else:
#            found = 'other'
#        ## TODO: transfer this into distributor:
#        if (found == 'ndarray' or found == 'd2o') and found_boolean == True:
#            ## extract the data of local relevance
#            local_bool_array = self.distributor.extract_local_data(key)
#            local_results = self.get_local_data(copy=False)[local_bool_array]
#            global_results = self.distributor._allgather(local_results)
#            global_results = np.concatenate(global_results)
#            return global_results            
#            
#        else:
#            return self.get_data(key)
ultimanet's avatar
ultimanet committed
548
549
550
551
    
    def __setitem__(self, key, data):
        """Implements ``self[key] = data`` by delegating to `set_data`."""
        ## set_data takes the data first and the key second
        self.set_data(data, key)
        
552
    def _contraction_helper(self, function, **kwargs):
        """
            Two-stage reduction: `function` is applied to the local data
            (with `kwargs`), the per-node results are gathered, and
            `function` is applied to that list along axis 0.
        """
        node_result = function(self.data, **kwargs)
        gathered_results = self.distributor._allgather(node_result)
        return function(gathered_results, axis=0)
        
    def amin(self, **kwargs):
        """Reduces the data with ``np.amin`` via `_contraction_helper`."""
        reducer = np.amin
        return self._contraction_helper(reducer, **kwargs)
560
561

    def nanmin(self, **kwargs):
        """Reduces the data with ``np.nanmin`` via `_contraction_helper`."""
        reducer = np.nanmin
        return self._contraction_helper(reducer, **kwargs)
563
564
        
    def amax(self, **kwargs):
        """Reduces the data with ``np.amax`` via `_contraction_helper`."""
        reducer = np.amax
        return self._contraction_helper(reducer, **kwargs)
566
567
    
    def nanmax(self, **kwargs):
        """Reduces the data with ``np.nanmax`` via `_contraction_helper`."""
        reducer = np.nanmax
        return self._contraction_helper(reducer, **kwargs)
Ultimanet's avatar
Ultimanet committed
569
    
570
571
572
573
574
575
    def sum(self, **kwargs):
        """Reduces the data with ``np.sum`` via `_contraction_helper`."""
        reducer = np.sum
        return self._contraction_helper(reducer, **kwargs)

    def prod(self, **kwargs):
        """Reduces the data with ``np.prod`` via `_contraction_helper`."""
        reducer = np.prod
        return self._contraction_helper(reducer, **kwargs)
        
576
577
578
579
580
581
582
583
    def mean(self, power=1):
        """
            Computes the global mean of ``data**power``.

            Every node contributes its local mean weighted by its number of
            local elements.

            Parameters
            ----------
            power : scalar, *optional*
                Exponent applied to the data before averaging (default: 1).

            Returns
            -------
            out : scalar
                The weighted mean of the local means.
        """
        ## compute the local means and the weights for the mean-mean.
        local_weight = np.prod(self.data.shape)
        if local_weight == 0:
            ## the mean of an empty array is nan, and nan*0 would turn the
            ## global result into nan; an empty node must contribute nothing
            local_mean = 0
        else:
            local_mean = np.mean(self.data**power)
        ## collect the local means and cast the result to a ndarray
        local_mean_weight_list = self.distributor._allgather((local_mean,
                                                              local_weight))
        local_mean_weight_list = np.array(local_mean_weight_list)
        ## compute the denominator for the weighted mean-mean
        global_weight = np.sum(local_mean_weight_list[:, 1])
        ## compute the numerator
        numerator = np.sum(local_mean_weight_list[:, 0] *
                           local_mean_weight_list[:, 1])
        global_mean = numerator/global_weight
        return global_mean

    def var(self):
        """
            Returns the variance, computed as the mean of the squares minus
            the square of the mean.
        """
        ## NOTE(review): for complex data this is mean(x**2) - mean(x)**2,
        ## without taking absolute values -- confirm this is intended
        second_moment = self.mean(power=2)
        first_moment = self.mean()
        return second_moment - first_moment**2
    
    def std(self):
        """Returns the standard deviation: the square root of `var`."""
        variance = self.var()
        return np.sqrt(variance)
        
600
601
602
603
#    def _argmin_argmax_flat_helper(self, function):
#        local_argmin = function(self.data)
#        local_argmin_value = self.data[np.unravel_index(local_argmin, 
#                                                        self.data.shape)]
604
#        globalized_local_argmin = self.distributor.globalize_flat_index(local_argmin) 
605
606
607
608
609
610
#        local_argmin_list = self.distributor._allgather((local_argmin_value, 
#                                                         globalized_local_argmin))
#        local_argmin_list = np.array(local_argmin_list, dtype=[('value', int),
#                                                               ('index', int)])    
#        return local_argmin_list
#        
611
612
613
614
    def argmin_flat(self):
        """
            Returns the global flat index of the minimum.

            Each node reports its local minimum together with the
            globalized flat index; the candidates are sorted by value and
            then by index, and the index of the first entry is returned.
        """
        local_data = self.data
        flat_position = np.argmin(local_data)
        position = np.unravel_index(flat_position, local_data.shape)
        smallest_value = local_data[position]
        global_position = self.distributor.globalize_flat_index(flat_position)

        candidates = self.distributor._allgather((smallest_value,
                                                  global_position))
        record_type = [('value', smallest_value.dtype), ('index', int)]
        candidates = np.array(candidates, dtype=record_type)
        ranking = np.sort(candidates, order=['value', 'index'])
        return ranking[0][1]
    
    def argmax_flat(self):
        """
            Returns the global flat index of the maximum.

            Each node reports its local maximum together with the
            globalized flat index. Among all candidates carrying the largest
            value, the smallest index is returned.

            Returns
            -------
            out : int
                Flat index into the global array.
        """
        local_argmax = np.argmax(self.data)
        local_argmax_value = self.data[np.unravel_index(local_argmax,
                                                        self.data.shape)]
        globalized_local_argmax = self.distributor.globalize_flat_index(
                                                                local_argmax)
        local_argmax_list = self.distributor._allgather((local_argmax_value,
                                                    globalized_local_argmax))
        local_argmax_list = np.array(local_argmax_list, dtype=[
                                        ('value', local_argmax_value.dtype),
                                        ('index', int)])
        values = local_argmax_list['value']
        indices = local_argmax_list['index']

        ## The values are compared as they are. Negating them in order to
        ## re-use an ascending sort wraps around for unsigned integers
        ## (0 would rank as the maximum) and is undefined for booleans.
        best_position = np.argmax(values)
        winners = indices[values == values[best_position]]
        if winners.size == 0:
            ## the best value does not equal itself (nan)
            winners = indices[best_position:best_position + 1]
        return winners.min()
        

    def argmin(self):
        """Returns the index tuple of the minimum in the global shape."""
        flat_index = self.argmin_flat()
        return np.unravel_index(flat_index, self.shape)
    
    def argmax(self):
        """Returns the index tuple of the maximum in the global shape."""
        flat_index = self.argmax_flat()
        return np.unravel_index(flat_index, self.shape)
    
    def conjugate(self):
        """
            Returns a new object holding the complex conjugate of the data.

            The result is created with `copy_empty` and filled with the
            conjugated local data; `self` is left untouched.
        """
        conjugated = self.copy_empty()
        conjugated.set_local_data(np.conj(self.get_local_data()))
        return conjugated

    
    def conj(self):
        """
            Short-hand for `conjugate`.
        """
        result = self.conjugate()
        return result
        
    def median(self):
        """
            Returns the median of the fully consolidated data.

            A warning is printed on every call, because the complete data is
            fetched via `get_full_data` before `numpy.median` is applied.
        """
        about.warnings.cprint(
            "WARNING: The current implementation of median is very expensive!")
        return np.median(self.get_full_data())
        
664
    def iscomplex(self):
        """
            Returns a boolean object flagging the entries of the local data
            for which `numpy.iscomplex` is True.

            The result is created with `copy_empty(dtype=np.bool_)`.
        """
        temp_d2o = self.copy_empty(dtype=np.bool_)
        temp_d2o.set_local_data(np.iscomplex(self.data))
        return temp_d2o
    
    def isreal(self):
        """
            Returns a boolean object flagging the entries of the local data
            for which `numpy.isreal` is True.

            The result is created with `copy_empty(dtype=np.bool_)`.
        """
        temp_d2o = self.copy_empty(dtype=np.bool_)
        temp_d2o.set_local_data(np.isreal(self.data))
        return temp_d2o
    
674

675
676
677
678
679
680
681
682
    def all(self):
        """
            Returns True if every entry on every node evaluates to True.

            Each node reduces its local data first; the per-node results are
            then gathered and reduced once more.
        """
        verdict_here = np.all(self.get_local_data())
        verdicts = self.distributor._allgather(verdict_here)
        return np.all(verdicts)

    def any(self):
        """
            Returns True if at least one entry on any node evaluates to True.

            Each node reduces its local data first; the per-node results are
            then gathered and reduced once more.
        """
        local_any = np.any(self.get_local_data())
        global_any = self.distributor._allgather(local_any)
        return np.any(global_any)
684
        
685
686
687
688
689
690
691
    def unique(self):
        """
            Returns the sorted unique values of the data of all nodes.

            The local data is reduced to its unique values before it is
            gathered, so only those values are exchanged between the nodes.
        """
        per_node = self.distributor._allgather(
            np.unique(self.get_local_data()))
        return np.unique(np.concatenate(per_node))
        
    def bincount(self, weights=None, minlength=None):
        """
            Counts the occurrences of each value in the distributed data.

            Every node runs `numpy.bincount` on its flattened local data; the
            per-node counts are gathered and summed.

            Parameters
            ----------
            weights : array-like *optional*
                Weights for the entries. The part belonging to this node is
                obtained from `distributor.extract_local_data`.
            minlength : int *optional*
                Minimal number of bins of the result. The result always has
                at least `self.amax() + 1` bins.

            Returns
            -------
            counts : numpy.ndarray
                The summed counts of all nodes.

            Raises
            ------
            TypeError
                If the dtype is not one of the 16-, 32- or 64-bit (unsigned)
                integer types.
        """
        if self.dtype not in [np.dtype('int16'), np.dtype('int32'),
                              np.dtype('int64'), np.dtype('uint16'),
                              np.dtype('uint32'), np.dtype('uint64')]:
            raise TypeError(about._errors.cstring(
                "ERROR: Distributed-data-object must be of integer datatype!"))

        # All nodes must use the same number of bins, otherwise the gathered
        # counts could not be summed. int() keeps the length an integer even
        # if amax() returns an unsigned 64-bit scalar.
        required_length = int(self.amax()) + 1
        # Comparing an int with None raises a TypeError on Python 3.
        if minlength is not None:
            required_length = max(required_length, minlength)

        if weights is not None:
            local_weights = self.distributor.extract_local_data(
                weights).flatten()
        else:
            local_weights = None

        local_counts = np.bincount(self.get_local_data().flatten(),
                                   weights=local_weights,
                                   minlength=required_length)
        list_of_counts = self.distributor._allgather(local_counts)
        return np.sum(list_of_counts, axis=0)
                              
713
    
714
    def set_local_data(self, data, hermitian=False, copy=True):
        """
            Stores data directly in the local data attribute. No distribution
            is done. The shape of the data is expected to fit the local data
            attribute's shape; this is not checked here.

            Parameters
            ----------
            data : tuple, list, numpy.ndarray
                The data which should be stored in the local data attribute.
                It is converted to `self.dtype` in C order.
            hermitian : boolean *optional*
                Value stored in `self.hermitian`.
            copy : boolean *optional*
                If True (default) the data is always copied. If False, a copy
                is made only if the conversion requires one.

            Returns
            -------
            None

        """
        self.hermitian = hermitian
        if copy:
            self.data = np.array(data, dtype=self.dtype, copy=True, order='C')
        else:
            # np.array(..., copy=False) raises on NumPy >= 2 if a copy is
            # unavoidable; np.asarray keeps the copy-only-if-needed meaning.
            self.data = np.asarray(data, dtype=self.dtype, order='C')
ultimanet's avatar
ultimanet committed
732
    
733
734
    def set_data(self, data, to_key, from_key=None, local_to_keys=False,
                 hermitian=False, copy=True, **kwargs):
        """
            Stores the supplied data in the region which is specified by
            to_key. The work is delegated to `distributor.disperse_data`,
            which updates `self.data`.

            Parameters
            ----------
            data : tuple, list, numpy.ndarray
                The data which should be distributed.
            to_key : int, slice, tuple of int or slice
                The key is the object which specifies the region, where data
                will be stored in.
            from_key : *optional*
                Passed on to the distributor unchanged; presumably selects
                the part of `data` that is used -- TODO confirm.
            local_to_keys : boolean *optional*
                Passed on to the distributor unchanged; presumably states
                whether the nodes got individual keys -- TODO confirm.
            hermitian : boolean *optional*
                Value stored in `self.hermitian`.
            copy : boolean *optional*
                Passed on to the distributor unchanged.
            **kwargs
                Further keyword arguments for `distributor.disperse_data`.

            Returns
            -------
            None

        """
        self.hermitian = hermitian
        self.distributor.disperse_data(data=self.data,
                                       to_key=to_key,
                                       data_update=data,
                                       from_key=from_key,
                                       local_to_keys=local_to_keys,
                                       copy=copy,
                                       **kwargs)
ultimanet's avatar
ultimanet committed
769
    
770
    def set_full_data(self, data, hermitian=False, copy=True, **kwargs):
        """
            Distributes the supplied data to the nodes. The shape of data must
            match the shape of the distributed_data_object.

            Parameters
            ----------
            data : tuple, list, numpy.ndarray
                The data which should be distributed.
            hermitian : boolean *optional*
                Value stored in `self.hermitian`.
            copy : boolean *optional*
                Passed on to `distributor.distribute_data`.
            **kwargs
                Further keyword arguments for `distributor.distribute_data`.

            Notes
            -----
            set_full_data(foo) is equivalent to set_data(foo,slice(None)) but
            faster.

            Returns
            -------
            None

        """
        self.hermitian = hermitian
        self.data = self.distributor.distribute_data(data=data, copy=copy,
                                                     **kwargs)
ultimanet's avatar
ultimanet committed
793

Ultimanet's avatar
Ultimanet committed
794
    def get_local_data(self, key=(slice(None),), copy=True):
        """
            Loads data directly from the local data attribute. No consolidation
            is done.

            Parameters
            ----------
            key : int, slice, tuple of int or slice
                The key which will be used to access the data. It is only
                applied if `copy` is true.
            copy : boolean *optional*
                If true (default), `self.data[key]` is returned. If false,
                the complete local array `self.data` itself is returned and
                `key` is ignored.

            Returns
            -------
            self.data[key] or self.data : numpy.ndarray

            Notes
            -----
            NOTE(review): despite its name, `copy=True` does not force a
            copy; the plain indexing result is handed out, which numpy
            returns as a view for basic slicing. Confirm whether callers
            rely on this before changing it.

        """
        if copy:
            return self.data[key]
        # Also reached for values like None, which formerly fell through and
        # returned None implicitly.
        return self.data
ultimanet's avatar
ultimanet committed
813
        
814
    def get_data(self, key, local_keys=False, **kwargs):
        """
            Loads data from the region which is specified by key. The data is
            consolidated according to the distribution strategy. If the
            individual nodes get different key-arguments, they get individual
            data.

            Parameters
            ----------
            key : int, slice, tuple of int or slice
                The key is the object which specifies the region, where data
                will be loaded from.
            local_keys : boolean *optional*
                Passed on to `distributor.collect_data`; presumably states
                whether the nodes got individual keys -- TODO confirm.
            **kwargs
                Further keyword arguments for `distributor.collect_data`.

            Returns
            -------
            The result of `distributor.collect_data`.

        """
        return self.distributor.collect_data(self.data,
                                             key,
                                             local_keys=local_keys,
                                             **kwargs)
ultimanet's avatar
ultimanet committed
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
        
    
    
    def get_full_data(self, target_rank='all'):
        """
            Fully consolidates the distributed data.

            Parameters
            ----------
            target_rank : 'all' (default), int *optional*
                If only one node should receive the full data, it can be
                specified here.

            Notes
            -----
            get_full_data() is equivalent to get_data(slice(None)) but
            faster.

            Returns
            -------
            The result of `distributor.consolidate_data`.
        """
        return self.distributor.consolidate_data(self.data,
                                                 target_rank=target_rank)
ultimanet's avatar
ultimanet committed
865

Ultimanet's avatar
Ultimanet committed
866
867
868
869
870
871
872
    def inject(self, to_slices=(slice(None),), data=None,
               from_slices=(slice(None),)):
        """
            Hands `data` to `distributor.inject` in order to write it into
            the local data.

            Parameters
            ----------
            to_slices : tuple of slices *optional*
                Passed on to the distributor; presumably the target region
                -- TODO confirm.
            data : *optional*
                The data to be injected. If None, nothing is done.
            from_slices : tuple of slices *optional*
                Passed on to the distributor; presumably the source region
                within `data` -- TODO confirm.

            Returns
            -------
            self if `data` is None, otherwise None.
        """
        # An identity test is required here: `data == None` is evaluated
        # element-wise for numpy arrays and has no unambiguous truth value.
        if data is None:
            return self

        self.distributor.inject(self.data, to_slices, data, from_slices)
        
873
874
875
876
877
878
879
880
881
882
883
    def flatten(self, inplace=False):
        """
            Returns a new, one-dimensional object holding the flattened data.

            Parameters
            ----------
            inplace : boolean *optional*
                Passed on to `distributor.flatten`. The object the method is
                called on is never replaced: rebinding `self` inside the
                method, as was done formerly, has no effect for the caller.

            Returns
            -------
            A new object of global shape `(prod(self.shape),)`.
        """
        flat_shape = (np.prod(self.shape),)
        temp_d2o = self.copy_empty(global_shape=flat_shape)
        flat_data = self.distributor.flatten(self.data, inplace=inplace)
        temp_d2o.set_local_data(data=flat_data)
        return temp_d2o
        
Ultimanet's avatar
Ultimanet committed
884
        
885

ultimanet's avatar
ultimanet committed
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
      
    def save(self, alias, path=None, overwriteQ=True):
        """
            Writes the data to an hdf5 file (h5py); the work is done by the
            distributor's `save_data`.

            Parameters
            ----------
            alias : string
                Name of the dataset inside the hdf5 file.

            path : string *optional*
                Location of the hdf5 file. Without a path the alias serves
                as filename in the current directory.

            overwriteQ : Boolean *optional*
                Whether a dataset that already exists in the file may be
                overwritten.
        """
        distributor = self.distributor
        distributor.save_data(self.data, alias, path, overwriteQ)

    def load(self, alias, path=None):
        """
            Reads the data from an hdf5 file (h5py) and stores whatever the
            distributor's `load_data` returns as local data.

            Parameters
            ----------
            alias : string
                Name of the dataset inside the hdf5 file.

            path : string *optional*
                Location of the hdf5 file. Without a path the alias serves
                as filename in the current directory.
        """
        loaded = self.distributor.load_data(alias, path)
        self.data = loaded
           

    
924
925
926
927
class _distributor_factory(object):
    """
        Creates distributor objects and stores them, so that a request with
        the same parsed arguments gets the already existing distributor.
    """

    def __init__(self):
        # Maps the hashed, parsed arguments to the distributor built from
        # them.
        self.distributor_store = {}

    def parse_kwargs(self, distribution_strategy, comm,
                     global_data=None, global_shape=None,
                     local_data=None, local_shape=None,
                     alias=None, path=None,
                     dtype=None, **kwargs):
        """
            Reduces the supplied arguments to those a distributor needs.

            Parameters
            ----------
            distribution_strategy : string
                'not', 'equal' and 'fftw' are treated as strategies working
                on a global shape, 'freeform' as working on local shapes.
            comm : MPI communicator
                Used to check that all nodes got consistent arguments.
            global_data, global_shape : *optional*
                Source for dtype and shape for the global-type strategies.
            local_data, local_shape : *optional*
                Source for dtype and shape for the 'freeform' strategy.
            alias, path : string *optional*
                If an alias is given and h5py is available, dtype and shape
                of the global-type strategies are read from that dataset.
            dtype : *optional*
                Explicit datatype; takes precedence over the data's dtype
                unless an hdf5 dataset is used.
            **kwargs
                Ignored.

            Returns
            -------
            return_dict : dict
                Contains 'comm' and 'dtype', plus 'global_shape' or
                'local_shape' and, except for 'not', the 'name' of the
                strategy.

            Raises
            ------
            ValueError
                If comm is None or if dtype or shape cannot be determined.
            AssertionError
                If the nodes disagree on strategy, dtype or global shape.
        """
        ## Parse the MPI communicator
        # This has to happen before comm is used for the first time.
        if comm is None:
            raise ValueError(about._errors.cstring(
                "ERROR: The distributor needs a MPI communicator object comm!"))

        return_dict = {'comm': comm}

        ## Check that all nodes got the same distribution_strategy
        strat_list = comm.allgather(distribution_strategy)
        assert(all(x == strat_list[0] for x in strat_list))

        ## Check for an hdf5 file and open it if given
        hdf5_file = None
        if alias is not None and FOUND['h5py']:
            file_path = path if (path is not None) else alias
            if FOUND['h5py_parallel'] and FOUND['MPI']:
                hdf5_file = h5py.File(file_path, 'r', driver='mpio',
                                      comm=comm)
            else:
                hdf5_file = h5py.File(file_path, 'r')

        # try/finally closes the file handle on the error paths, too.
        try:
            dset = hdf5_file[alias] if hdf5_file is not None else None

            ## Parse the datatype
            if distribution_strategy in ['not', 'equal', 'fftw']:
                if dset is not None:
                    dtype = dset.dtype
                elif dtype is not None:
                    dtype = np.dtype(dtype)
                elif global_data is None:
                    raise ValueError(about._errors.cstring(
                        "ERROR: Neither global_data nor dtype supplied!"))
                else:
                    try:
                        dtype = global_data.dtype
                    except AttributeError:
                        dtype = np.array(global_data).dtype

            elif distribution_strategy in ['freeform']:
                if dtype is not None:
                    dtype = np.dtype(dtype)
                elif global_data is None and local_data is None:
                    raise ValueError(about._errors.cstring(
                        "ERROR: Neither nor local_data nor dtype supplied!"))
                else:
                    try:
                        dtype = local_data.dtype
                    except AttributeError:
                        dtype = np.array(local_data).dtype

            dtype_list = comm.allgather(dtype)
            assert(all(x == dtype_list[0] for x in dtype_list))
            return_dict['dtype'] = dtype

            ## Parse the shape
            ## Case 1: global-type slicer
            if distribution_strategy in ['not', 'equal', 'fftw']:
                if dset is not None:
                    global_shape = dset.shape
                elif global_data is not None and \
                        not np.isscalar(global_data):
                    # np.shape also handles lists and tuples, which the
                    # dtype parsing above accepts as well.
                    global_shape = np.shape(global_data)
                elif global_shape is not None:
                    global_shape = tuple(global_shape)
                else:
                    raise ValueError(about._errors.cstring(
                        "ERROR: Neither non-0-dimensional global_data nor " +
                        "global_shape nor hdf5 file supplied!"))
                if global_shape == ():
                    raise ValueError(about._errors.cstring(
                        "ERROR: global_shape == () is not valid shape!"))

                global_shape_list = comm.allgather(global_shape)
                assert(all(x == global_shape_list[0]
                           for x in global_shape_list))
                return_dict['global_shape'] = global_shape

            ## Case 2: local-type slicer
            elif distribution_strategy in ['freeform']:
                if local_data is not None and not np.isscalar(local_data):
                    local_shape = np.shape(local_data)
                elif local_shape is not None:
                    local_shape = tuple(local_shape)
                else:
                    raise ValueError(about._errors.cstring(
                        "ERROR: Neither non-0-dimensional local_data nor " +
                        "local_shape supplied!"))
                return_dict['local_shape'] = local_shape

            ## Add the name of the distributor if needed
            if distribution_strategy in ['equal', 'fftw', 'freeform']:
                return_dict['name'] = distribution_strategy
        finally:
            ## close the file-handle
            if hdf5_file is not None:
                hdf5_file.close()

        return return_dict

    def hash_arguments(self, distribution_strategy, **kwargs):
        """
            Turns the parsed arguments into a hashable key.

            Parameters
            ----------
            distribution_strategy : string
                Becomes part of the key.
            **kwargs
                The parsed arguments. 'comm' and 'dtype' are required.

            Returns
            -------
            key : frozenset
                The items of the arguments, in which the communicator is
                replaced by its id, the dtype by the result of
                `dictionize_np` and a local shape by the tuple of the local
                shapes of all nodes.
        """
        hashable = dict(kwargs)

        comm = hashable['comm']
        hashable['comm'] = id(comm)

        if 'local_shape' in hashable:
            # The local shapes of all nodes enter the key, hence all nodes
            # end up with the same one.
            hashable['local_shape'] = tuple(
                comm.allgather(hashable['local_shape']))

        hashable['dtype'] = self.dictionize_np(hashable['dtype'])
        hashable['distribution_strategy'] = distribution_strategy

        return frozenset(hashable.items())

    def dictionize_np(self, x):
        """
            Returns a hashable representation of the numpy dtype x, built
            from the items of the `__dict__` of `x.type`.

            NOTE(review): dtypes sharing the same scalar type (e.g. the same
            type with different byte order) yield the same result.
        """
        return frozenset(x.type.__dict__.items())

    def get_distributor(self, distribution_strategy, comm, **kwargs):
        """
            Returns the distributor for the given strategy and arguments. It
            is created on the first request and taken from the store
            afterwards.

            Parameters
            ----------
            distribution_strategy : string
                'not', 'equal', 'freeform' or, if pyfftw is available,
                'fftw'.
            comm : MPI communicator
            **kwargs
                Passed on to `parse_kwargs`.

            Raises
            ------
            TypeError
                If the distribution strategy is unknown.
        """
        ## check if the distribution strategy is known
        known_distribution_strategies = ['not', 'equal', 'freeform']
        if FOUND['pyfftw']:
            known_distribution_strategies.append('fftw')
        if distribution_strategy not in known_distribution_strategies:
            raise TypeError(about._errors.cstring(
                "ERROR: Unknown distribution strategy supplied."))

        ## parse the kwargs
        parsed_kwargs = self.parse_kwargs(
            distribution_strategy=distribution_strategy,
            comm=comm,
            **kwargs)
        hashed_kwargs = self.hash_arguments(distribution_strategy,
                                            **parsed_kwargs)

        ## check if the distributors has already been produced in the past
        if hashed_kwargs in self.distributor_store:
            return self.distributor_store[hashed_kwargs]

        ## produce new distributor
        if distribution_strategy == 'not':
            produced_distributor = _not_distributor(**parsed_kwargs)
        elif distribution_strategy == 'equal':
            produced_distributor = _slicing_distributor(
                slicer=_equal_slicer,
                **parsed_kwargs)
        elif distribution_strategy == 'fftw':
            produced_distributor = _slicing_distributor(
                slicer=_fftw_slicer,
                **parsed_kwargs)
        elif distribution_strategy == 'freeform':
            produced_distributor = _slicing_distributor(
                slicer=_freeform_slicer,
                **parsed_kwargs)

        self.distributor_store[hashed_kwargs] = produced_distributor
        return self.distributor_store[hashed_kwargs]
1120
1121
1122
            
            
# Module-level factory instance; its distributor_store keeps the distributors
# that have been created through get_distributor.
distributor_factory = _distributor_factory()
ultimanet's avatar
ultimanet committed
1123
        
1124
class _slicing_distributor(object):
Ultima's avatar
Ultima committed
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
    def __init__(self, slicer, name, dtype, comm, **remaining_parsed_kwargs):
        """
            Sets up the bookkeeping of a distributor that splits the data
            along its first axis.

            Parameters
            ----------
            slicer : callable
                Called as `slicer(comm=comm, **remaining_parsed_kwargs)`; its
                result is indexed as (local_start, local_end, global_shape).
            name : string
                Stored as `distribution_strategy`.
            dtype : numpy.dtype or anything `numpy.dtype` accepts
                The datatype of the data.
            comm : MPI communicator
            **remaining_parsed_kwargs
                Passed on to the slicer.

            Raises
            ------
            TypeError
                If the dtype converter does not know the datatype.
        """
        self.comm = comm
        self.distribution_strategy = name
        self.dtype = np.dtype(dtype)

        self._my_dtype_converter = global_dtype_converter

        if not self._my_dtype_converter.known_np_Q(self.dtype):
            raise TypeError(about._errors.cstring(
                "ERROR: The datatype " + repr(self.dtype) +
                " is not known to mpi4py."))

        self.mpi_dtype = self._my_dtype_converter.to_mpi(self.dtype)

        self.slicer = slicer
        self._local_size = self.slicer(comm=comm, **remaining_parsed_kwargs)
        self.local_start = self._local_size[0]
        self.local_end = self._local_size[1]
        self.global_shape = self._local_size[2]

        self.local_length = self.local_end - self.local_start
        self.local_shape = (self.local_length,) + tuple(self.global_shape[1:])
        self.local_dim = np.prod(self.local_shape)

        # NOTE(review): the buffers use numpy's default integer while the
        # MPI datatype is MPI.INT -- confirm that the sizes are meant to
        # differ on platforms where the default integer has 64 bits.
        self.local_dim_list = np.empty(comm.size, dtype=int)
        comm.Allgather([np.array(self.local_dim, dtype=int), MPI.INT],
                       [self.local_dim_list, MPI.INT])
        # Number of elements held by the nodes with a lower rank
        self.local_dim_offset = np.sum(self.local_dim_list[0:comm.rank])

        self.local_slice = np.array([self.local_start, self.local_end,
                                     self.local_length, self.local_dim,
                                     self.local_dim_offset],
                                    dtype=int)
        ## collect all local_slices
        ## [start, stop, length=stop-start, dimension, dimension_offset]
        self.all_local_slices = np.empty((comm.size, 5), dtype=int)
        comm.Allgather([np.array((self.local_slice,), dtype=int), MPI.INT],
                       [self.all_local_slices, MPI.INT])
ultimanet's avatar
ultimanet committed
1186
        
Ultima's avatar
Ultima committed
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
    def initialize_data(self, global_data, local_data, alias, path, hermitian,
                        copy, **kwargs):
        """
            Produces the initial local data of this node.

            Parameters
            ----------
            global_data : scalar or array-like
                Used by the strategies 'equal' and 'fftw'. A scalar fills
                the local array, anything else goes to `distribute_data`.
            local_data : scalar, array-like or None
                Used by the strategy 'freeform'. A scalar fills the local
                array, None yields an uninitialized array, anything else is
                converted to `self.dtype` and reshaped to the local shape.
            alias, path : string
                If an alias is given and h5py is available, the data is
                loaded with `load_data` and the other sources are ignored.
            hermitian : boolean
                Returned as it is, except for scalar input, which yields
                True.
            copy : boolean
                Passed on to `distribute_data` and `astype`, respectively.
            **kwargs
                Ignored.

            Returns
            -------
            (local_data, hermitian) : tuple

            Raises
            ------
            TypeError
                If the distribution strategy is none of the above.
        """
        if alias is not None and FOUND['h5py']:
            local_data = self.load_data(alias=alias, path=path)
            return (local_data, hermitian)

        if self.distribution_strategy in ['equal', 'fftw']:
            if np.isscalar(global_data):
                local_data = np.empty(self.local_shape, dtype=self.dtype)
                local_data.fill(global_data)
                hermitian = True
            else:
                local_data = self.distribute_data(data=global_data,
                                                  copy=copy)
        elif self.distribution_strategy in ['freeform']:
            if np.isscalar(local_data):
                filled = np.empty(self.local_shape, dtype=self.dtype)
                filled.fill(local_data)
                local_data = filled
                hermitian = True
            elif local_data is None:
                local_data = np.empty(self.local_shape, dtype=self.dtype)
            else:
                local_data = np.array(local_data).astype(
                    self.dtype, copy=copy).reshape(self.local_shape)
        else:
            raise TypeError(about._errors.cstring(
                                        "ERROR: Unknown istribution strategy"))
        return (local_data, hermitian)
ultimanet's avatar
ultimanet committed
1217
        
1218
1219
1220
1221
1222
1223
    def globalize_flat_index(self, index):
        """
            Converts a flat index of the local data into a flat global index
            by adding this node's `local_dim_offset`.
        """
        local_position = int(index)
        return local_position + self.local_dim_offset
        
    def globalize_index(self, index):
        """
            Converts a local index tuple into a global one.

            The first entry is shifted by `local_start`; afterwards every
            entry is clipped to the range [-n, n-1], n being the length of
            the respective axis of the global shape. A warning is printed if
            the clipping changed an entry.

            Parameters
            ----------
            index : tuple or list of int
                One entry per axis of the global shape.

            Returns
            -------
            globalized_index : tuple

            Raises
            ------
            TypeError
                If the number of entries does not match the number of axes.
        """
        index = np.array(index, dtype=int).flatten()
        if index.shape != (len(self.global_shape),):
            raise TypeError(about._errors.cstring(
                "ERROR: Length of index tuple does not match the array's "
                "shape!"))
        index[0] = index[0] + self.local_start
        ## ensure that the globalized index list is within the bounds
        axis_lengths = np.array(self.global_shape)
        clipped_index = np.clip(index, -axis_lengths, axis_lengths - 1)
        if np.any(index != clipped_index):
            about.warnings.cprint("WARNING: Indices were clipped!")
        return tuple(clipped_index)
    
    def _allgather(self, thing, comm=None):
        """
            Gathers `thing` from all nodes with the communicator's
            `allgather` and returns the result.

            Parameters
            ----------
            thing : object
                The object every node contributes.
            comm : MPI communicator *optional*
                Defaults to the distributor's own communicator.
        """
        if comm is None:
            comm = self.comm
        return comm.allgather(thing)
    
1244
    def distribute_data(self, data=None, alias=None,
1245
                        path=None, copy=True, **kwargs):
ultimanet's avatar
ultimanet committed
1246
1247
1248
1249
1250
        '''
        distribute data checks 
        - whether the data is located on all nodes or only on node 0
        - that the shape of 'data' matches the global_shape
        '''
1251
1252
        
        comm = self.comm            
1253
1254
        rank = comm.Get_rank()
        size = comm.Get_size()        
1255
        local_data_available_Q = np.array((int(data is not None), ))
1256
        data_available_Q = np.empty(size, dtype=int)
1257
1258
        comm.Allgather([local_data_available_Q, MPI.INT], 
                       [data_available_Q, MPI.INT])        
1259
        
Ultima's avatar
Ultima committed
1260
        if data_available_Q[0]==False and FOUND['h5py']:
ultimanet's avatar
ultimanet committed
1261
1262
            try: 
                file_path = path if path != None else alias 
Ultima's avatar
Ultima committed
1263
                if FOUND['h5py_parallel']:
ultimanet's avatar
ultimanet committed
1264
1265
1266
1267
                    f = h5py.File(file_path, 'r', driver='mpio', comm=comm)
                else:
                    f= h5py.File(file_path, 'r')        
                dset = f[alias]
1268
                if dset.shape == self.global_shape and \
1269
                 dset.dtype == self.dtype:
ultimanet's avatar
ultimanet committed
1270
1271
1272
1273
                    temp_data = dset[self.local_start:self.local_end]
                    f.close()
                    return temp_data
                else:
Ultimanet's avatar
Ultimanet committed
1274
                    raise TypeError(about._errors.cstring("ERROR: \
1275
                    Input data has the wrong shape or wrong dtype!"))                 
ultimanet's avatar
ultimanet committed
1276
1277
1278
            except(IOError, AttributeError):
                pass
            
1279
        if np.all(data_available_Q==False):
Ultimanet's avatar
Ultimanet committed
1280
            return np.empty(self.local_shape, dtype=self.dtype, order='C')
ultimanet's avatar
ultimanet committed
1281
        ## if all nodes got data, we assume that it is the right data and 
1282
1283
        ## store it individually. If not, take the data on node 0 and scatter 
        ## it...
ultimanet's avatar
ultimanet committed
1284
        if np.all(data_available_Q):
1285
            return data[self.local_start:self.local_end].astype(self.dtype,\