nifty_mpi_data.py 140 KB
Newer Older
ultimanet's avatar
ultimanet committed
1
# -*- coding: utf-8 -*-
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
## NIFTY (Numerical Information Field Theory) has been developed at the
## Max-Planck-Institute for Astrophysics.
##
## Copyright (C) 2015 Max-Planck-Society
##
## Author: Theo Steininger
## Project homepage: <http://www.mpa-garching.mpg.de/ift/nifty/>
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
## See the GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <http://www.gnu.org/licenses/>.


ultimanet's avatar
ultimanet committed
24
25


26

Ultima's avatar
Ultima committed
27
28
##initialize the 'FOUND-packages'-dictionary 
FOUND = {}
ultimanet's avatar
ultimanet committed
29
import numpy as np
Ultimanet's avatar
Ultimanet committed
30
from nifty_about import about
31
from weakref import WeakValueDictionary as weakdict
ultimanet's avatar
ultimanet committed
32
33

try:
34
    from mpi4py import MPI
Ultima's avatar
Ultima committed
35
    FOUND['MPI'] = True
ultimanet's avatar
ultimanet committed
36
except(ImportError): 
37
    import mpi_dummy as MPI
Ultima's avatar
Ultima committed
38
    FOUND['MPI'] = False
ultimanet's avatar
ultimanet committed
39
40
41

try:
    import pyfftw
Ultima's avatar
Ultima committed
42
    FOUND['pyfftw'] = True
ultimanet's avatar
ultimanet committed
43
except(ImportError):       
Ultima's avatar
Ultima committed
44
    FOUND['pyfftw'] = False
ultimanet's avatar
ultimanet committed
45
46

try:
47
    import h5py
Ultima's avatar
Ultima committed
48
49
    FOUND['h5py'] = True
    FOUND['h5py_parallel'] = h5py.get_config().mpi
ultimanet's avatar
ultimanet committed
50
except(ImportError):
Ultima's avatar
Ultima committed
51
52
    FOUND['h5py'] = False
    FOUND['h5py_parallel'] = False
ultimanet's avatar
ultimanet committed
53
54


Ultima's avatar
Ultima committed
55
56
57
58
# Distribution strategies understood by this module, grouped by name.
# NOTE(review): judging by the names, "global" strategies derive the layout
# from the global shape and "local" ones from node-local input -- confirm
# against the distributor implementations.
GLOBAL_DISTRIBUTION_STRATEGIES = ['not', 'equal', 'fftw']
LOCAL_DISTRIBUTION_STRATEGIES = ['freeform']
ALL_DISTRIBUTION_STRATEGIES = (GLOBAL_DISTRIBUTION_STRATEGIES +
                               LOCAL_DISTRIBUTION_STRATEGIES)
# Presumably the strategies that support HDF5 I/O.
HDF5_DISTRIBUTION_STRATEGIES = ['equal', 'fftw']

COMM = MPI.COMM_WORLD
class distributed_data_object(object):
    """

        NIFTY class for distributed data

        Parameters
        ----------
        global_data : {tuple, list, numpy.ndarray} *at least 1-dimensional*
            Initial data which will be casted to a numpy.ndarray and then 
            stored according to the distribution strategy. The global_data's
            shape overwrites global_shape.
        global_shape : tuple of ints, *optional*
            If no global_data is supplied, global_shape can be used to
            initialize an empty distributed_data_object
        dtype : type, *optional*
            If an explicit dtype is supplied, the given global_data will be 
            casted to it.            
        distribution_strategy : {'fftw' (default), 'not'}, *optional*
            Specifies the way, how global_data will be distributed to the 
            individual nodes. 
            'fftw' follows the distribution strategy of pyfftw.
            'not' does not distribute the data at all. 
            

        Attributes
        ----------
        data : numpy.ndarray
            The numpy.ndarray in which the individual node's data is stored.
        dtype : type
            Data type of the data object.
        distribution_strategy : string
            Name of the used distribution_strategy
        distributor : distributor
            The distributor object which takes care of all distribution and 
            consolidation of the data. 
        shape : tuple of int
            The global shape of the data
            
        Raises
        ------
        TypeError : 
            If the supplied distribution strategy is not known. 
        
    """
106
    def __init__(self, global_data = None, global_shape=None, dtype=None, 
Ultima's avatar
Ultima committed
107
                 local_data=None, local_shape=None,
108
109
                 distribution_strategy='fftw', hermitian=False,
                 alias=None, path=None, comm = MPI.COMM_WORLD, 
110
                 copy = True, *args, **kwargs):
Ultima's avatar
Ultima committed
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#        
#        ## a given hdf5 file overwrites the other parameters
#        if FOUND['h5py'] == True and alias is not None:
#            ## set file path            
#            file_path = path if (path is not None) else alias 
#            ## open hdf5 file
#            if FOUND['h5py_parallel'] == True and FOUND['MPI'] == True:
#                f = h5py.File(file_path, 'r', driver='mpio', comm=comm)
#            else:
#                f= h5py.File(file_path, 'r')        
#            ## open alias in file
#            dset = f[alias] 
#            ## set shape 
#            global_shape = dset.shape
#            ## set dtype
#            dtype = dset.dtype.type

#        ## if no hdf5 path was given, extract global_shape and dtype from 
#        ## the remaining arguments
#        else:        
#            ## an explicitly given dtype overwrites the one from global_data
#            if dtype is None:
#                if global_data is None and local_data is None:
#                    raise ValueError(about._errors.cstring(
#            "ERROR: Neither global_data nor local_data nor dtype supplied!"))      
#                elif global_data is not None:
#                    try:
#                        dtype = global_data.dtype.type
#                    except(AttributeError):
#                        try:
#                            dtype = global_data.dtype
#                        except(AttributeError):
#                            dtype = np.array(global_data).dtype.type
#                elif local_data is not None:
#                    try:
#                        dtype = local_data.dtype.type
#                    except(AttributeError):
#                        try:
#                            dtype = local_data.dtype
#                        except(AttributeError):
#                            dtype = np.array(local_data).dtype.type
#            else:
#                dtype = np.dtype(dtype).type
#            
#            ## an explicitly given global_shape argument is only used if 
#            ## 1. no global_data was supplied, or 
#            ## 2. global_data is a scalar/list of dimension 0.
#            
#            if global_data is not None and np.isscalar(global_data) == False:
#                global_shape = global_data.shape
#            elif global_shape is not None:
#                global_shape = tuple(global_shape)
#                
#            if local_data is not None
#            
##            if global_shape is None:
##                if global_data is None or np.isscalar(global_data):
##                    raise ValueError(about._errors.cstring(
##    "ERROR: Neither non-0-dimensional global_data nor global_shape supplied!"))      
##                global_shape = global_data.shape
##            else:
##                if global_data is None or np.isscalar(global_data):
##                    global_shape = tuple(global_shape)
##                else:
##                    global_shape = global_data.shape
        
        ## TODO: allow init with empty shape
178
179
180
181
182
183
        
        if isinstance(global_data, tuple) or isinstance(global_data, list):
            global_data = np.array(global_data, copy=False)
        if isinstance(local_data, tuple) or isinstance(local_data, list):
            local_data = np.array(local_data, copy=False)
        
184
185
        self.distributor = distributor_factory.get_distributor(
                                distribution_strategy = distribution_strategy,
186
                                comm = comm,
Ultima's avatar
Ultima committed
187
                                global_data = global_data,                                
188
                                global_shape = global_shape,
Ultima's avatar
Ultima committed
189
190
191
192
                                local_data = local_data,
                                local_shape = local_shape,
                                alias = alias,
                                path = path,
193
194
195
                                dtype = dtype,
                                **kwargs)
                                
ultimanet's avatar
ultimanet committed
196
197
198
        self.distribution_strategy = distribution_strategy
        self.dtype = self.distributor.dtype
        self.shape = self.distributor.global_shape
Ultima's avatar
Ultima committed
199
200
        self.local_shape = self.distributor.local_shape
        self.comm = self.distributor.comm
ultimanet's avatar
ultimanet committed
201
        
202
203
        self.init_args = args 
        self.init_kwargs = kwargs
204

Ultima's avatar
Ultima committed
205
206
207
208
        (self.data, self.hermitian) = self.distributor.initialize_data(
                                                     global_data = global_data,
                                                     local_data = local_data,
                                                     alias = alias,
Ultima's avatar
Ultima committed
209
                                                     path = path,
Ultima's avatar
Ultima committed
210
211
                                                     hermitian = hermitian,
                                                     copy = copy)
212
        self.index = d2o_librarian.register(self)
Ultima's avatar
Ultima committed
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
#        ## If a hdf5 path was given, load the data
#        if FOUND['h5py'] == True and alias is not None:
#            self.load(alias = alias, path = path)
#            ## close the file handle
#            f.close()
#            
#        ## If the input data was a scalar, set the whole array to this value
#        elif global_data is not None and np.isscalar(global_data):
#            temp = np.empty(self.distributor.local_shape, dtype = self.dtype)
#            temp.fill(global_data)
#            self.set_local_data(temp)
#            self.hermitian = True
#        else:
#            self.set_full_data(data=global_data, hermitian=hermitian, 
#                               copy = copy, **kwargs)
#            
Ultimanet's avatar
Ultimanet committed
229
230
231
232
    def copy(self, dtype=None, distribution_strategy=None, **kwargs):
        temp_d2o = self.copy_empty(dtype=dtype, 
                                   distribution_strategy=distribution_strategy, 
                                   **kwargs)     
Ultima's avatar
Ultima committed
233
        if distribution_strategy is None or \
Ultimanet's avatar
Ultimanet committed
234
235
236
            distribution_strategy == self.distribution_strategy:
            temp_d2o.set_local_data(self.get_local_data(), copy=True)
        else:
237
            #temp_d2o.set_full_data(self.get_full_data())
Ultima's avatar
Ultima committed
238
            temp_d2o.inject((slice(None),), self, (slice(None),))
239
        temp_d2o.hermitian = self.hermitian
240
241
        return temp_d2o
    
242
    def copy_empty(self, global_shape=None, local_shape=None, dtype=None,
                   distribution_strategy=None, **kwargs):
        """
            Create a new distributed_data_object without passing any data.

            Every argument that is None is taken from `self` (shape,
            local_shape, dtype, distribution_strategy). The communicator and
            the stored init_args are reused. The stored init_kwargs are
            merged into `kwargs` via dict.update, so on a key clash the value
            from the original construction wins.
            NOTE(review): confirm that this precedence is intended.
        """
        # Going from 'not' to a local strategy without an explicit
        # local_shape: self.local_shape would otherwise be used as the local
        # shape, so first build an 'equal'-distributed empty copy and derive
        # the requested local-strategy copy from that one.
        if (self.distribution_strategy == 'not' and
                distribution_strategy in LOCAL_DISTRIBUTION_STRATEGIES and
                local_shape is None):
            result = self.copy_empty(global_shape=global_shape,
                                     local_shape=local_shape,
                                     dtype=dtype,
                                     distribution_strategy='equal',
                                     **kwargs)
            return result.copy_empty(
                                distribution_strategy=distribution_strategy)

        if global_shape is None:
            global_shape = self.shape
        if local_shape is None:
            local_shape = self.local_shape
        if dtype is None:
            dtype = self.dtype
        if distribution_strategy is None:
            distribution_strategy = self.distribution_strategy

        kwargs.update(self.init_kwargs)

        temp_d2o = distributed_data_object(
                               global_shape=global_shape,
                               local_shape=local_shape,
                               dtype=dtype,
                               distribution_strategy=distribution_strategy,
                               comm=self.comm,
                               *self.init_args,
                               **kwargs)
        return temp_d2o

    def apply_scalar_function(self, function, inplace=False, dtype=None):
275
276
        remember_hermitianQ = self.hermitian
        
Ultimanet's avatar
Ultimanet committed
277
278
        if inplace == True:        
            temp = self
Ultima's avatar
Ultima committed
279
            if dtype is not None and self.dtype != np.dtype(dtype):
280
281
282
                about.warnings.cprint(\
            "WARNING: Inplace dtype conversion is not possible!")
                
Ultimanet's avatar
Ultimanet committed
283
        else:
284
            temp = self.copy_empty(dtype=dtype)
Ultimanet's avatar
Ultimanet committed
285

Ultima's avatar
Ultima committed
286
287
288
289
290
291
292
293
294
        if np.prod(self.local_shape) != 0:
            try: 
                temp.data[:] = function(self.data)
            except:
                temp.data[:] = np.vectorize(function)(self.data)
        else:
            ## Noting to do here. The value-empty array
            ## is also geometrically empty
            pass
295
        
296
297
298
299
        if function in (np.exp, np.log):
            temp.hermitian = remember_hermitianQ
        else:
            temp.hermitian = False
Ultimanet's avatar
Ultimanet committed
300
301
302
303
304
305
        return temp
    
    def apply_generator(self, generator):
        self.set_local_data(generator(self.distributor.local_shape))
        self.hermitian = False
            
ultimanet's avatar
ultimanet committed
306
307
308
309
310
311
    def __str__(self):
        return self.data.__str__()
    
    def __repr__(self):
        return '<distributed_data_object>\n'+self.data.__repr__()
    
312
313
    
    def _compare_helper(self, other, op):
314
        result = self.copy_empty(dtype = np.bool_)
Ultimanet's avatar
Ultimanet committed
315
316
317
        ## Case 1: 'other' is a scalar
        ## -> make point-wise comparison
        if np.isscalar(other):
318
319
            result.set_local_data(
                    getattr(self.get_local_data(copy = False), op)(other))
Ultimanet's avatar
Ultimanet committed
320
321
322
323
324
325
326
            return result        

        ## Case 2: 'other' is a numpy array or a distributed_data_object
        ## -> extract the local data and make point-wise comparison
        elif isinstance(other, np.ndarray) or\
        isinstance(other, distributed_data_object):
            temp_data = self.distributor.extract_local_data(other)
327
328
            result.set_local_data(
                getattr(self.get_local_data(copy=False), op)(temp_data))
Ultimanet's avatar
Ultimanet committed
329
330
331
            return result
        
        ## Case 3: 'other' is None
Ultima's avatar
Ultima committed
332
        elif other is None:
Ultimanet's avatar
Ultimanet committed
333
334
335
            return False
        
        ## Case 4: 'other' is something different
336
        ## -> make a numpy casting and make a recursive call
Ultimanet's avatar
Ultimanet committed
337
338
        else:
            temp_other = np.array(other)
339
            return getattr(self, op)(temp_other)
Ultimanet's avatar
Ultimanet committed
340
        
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359

    def __ne__(self, other):
        return self._compare_helper(other, '__ne__')
        
    def __lt__(self, other):
        return self._compare_helper(other, '__lt__')
            
    def __le__(self, other):
        return self._compare_helper(other, '__le__')

    def __eq__(self, other):

        return self._compare_helper(other, '__eq__')
    def __ge__(self, other):
        return self._compare_helper(other, '__ge__')

    def __gt__(self, other):
        return self._compare_helper(other, '__gt__')

Ultimanet's avatar
Ultimanet committed
360
    def equal(self, other):
Ultimanet's avatar
Ultimanet committed
361
362
363
364
365
366
367
368
369
        if other is None:
            return False
        try:
            assert(self.dtype == other.dtype)
            assert(self.shape == other.shape)
            assert(self.init_args == other.init_args)
            assert(self.init_kwargs == other.init_kwargs)
            assert(self.distribution_strategy == other.distribution_strategy)
            assert(np.all(self.data == other.data))
Ultimanet's avatar
Ultimanet committed
370
        except(AssertionError, AttributeError):
Ultimanet's avatar
Ultimanet committed
371
372
373
374
375
376
377
            return False
        else:
            return True
        

            
    
378
    def __pos__(self):
379
        temp_d2o = self.copy_empty()
380
        temp_d2o.set_local_data(data = self.get_local_data(), copy = True)
381
382
        return temp_d2o
        
ultimanet's avatar
ultimanet committed
383
    def __neg__(self):
384
        temp_d2o = self.copy_empty()
385
386
        temp_d2o.set_local_data(data = self.get_local_data().__neg__(),
                                copy = True) 
ultimanet's avatar
ultimanet committed
387
388
        return temp_d2o
    
389
    def __abs__(self):
Ultimanet's avatar
Ultimanet committed
390
        ## translate complex dtypes
391
392
393
394
395
396
        if self.dtype == np.dtype('complex64'):
            new_dtype = np.dtype('float32')
        elif self.dtype == np.dtype('complex128'):
            new_dtype = np.dtype('float64')
        elif issubclass(self.dtype.type, np.complexfloating):
            new_dtype = np.dtype('float')
Ultimanet's avatar
Ultimanet committed
397
398
399
        else:
            new_dtype = self.dtype
        temp_d2o = self.copy_empty(dtype = new_dtype)
400
401
        temp_d2o.set_local_data(data = self.get_local_data().__abs__(),
                                copy = True) 
402
        return temp_d2o
ultimanet's avatar
ultimanet committed
403
            
Ultima's avatar
Ultima committed
404
405
406
407
408
409
    def _builtin_helper(self, operator, other, inplace=False):
        if isinstance(other, distributed_data_object):
            other_is_real = other.isreal()
        else:
            other_is_real = np.isreal(other)
            
Ultimanet's avatar
Ultimanet committed
410
411
412
413
414
        ## Case 1: other is not a scalar
        if not (np.isscalar(other) or np.shape(other) == (1,)):
##            if self.shape != other.shape:            
##                raise AttributeError(about._errors.cstring(
##                    "ERROR: Shapes do not match!")) 
415
            try:            
416
                hermitian_Q = (other.hermitian and self.hermitian)
417
418
            except(AttributeError):
                hermitian_Q = False
Ultimanet's avatar
Ultimanet committed
419
420
421
            ## extract the local data from the 'other' object
            temp_data = self.distributor.extract_local_data(other)
            temp_data = operator(temp_data)
Ultimanet's avatar
Ultimanet committed
422
            
423
        ## Case 2: other is a real scalar -> preserve hermitianity
Ultima's avatar
Ultima committed
424
425
        elif other_is_real or (self.dtype not in (np.dtype('complex128'),
                                                  np.dtype('complex256'))):
426
            hermitian_Q = self.hermitian
ultimanet's avatar
ultimanet committed
427
            temp_data = operator(other)
428
429
430
431
        ## Case 3: other is complex
        else:
            hermitian_Q = False
            temp_data = operator(other)        
Ultimanet's avatar
Ultimanet committed
432
        ## write the new data into a new distributed_data_object        
433
434
435
        if inplace == True:
            temp_d2o = self
        else:
436
437
            ## use common datatype for self and other
            new_dtype = np.dtype(np.find_common_type((self.dtype,),
438
                                                     (temp_data.dtype,)))
439
440
            temp_d2o = self.copy_empty(
                            dtype = new_dtype)
ultimanet's avatar
ultimanet committed
441
        temp_d2o.set_local_data(data=temp_data)
442
        temp_d2o.hermitian = hermitian_Q
ultimanet's avatar
ultimanet committed
443
        return temp_d2o
444
    """
Ultimanet's avatar
Ultimanet committed
445
    def __inplace_builtin_helper__(self, operator, other):
446
        ## Case 1: other is not a scalar
Ultimanet's avatar
Ultimanet committed
447
448
449
        if not (np.isscalar(other) or np.shape(other) == (1,)):        
            temp_data = self.distributor.extract_local_data(other)
            temp_data = operator(temp_data)
450
451
452
        ## Case 2: other is a real scalar -> preserve hermitianity
        elif np.isreal(other):
            hermitian_Q = self.hermitian
Ultimanet's avatar
Ultimanet committed
453
            temp_data = operator(other)
454
455
456
        ## Case 3: other is complex
        else:
            temp_data = operator(other)        
Ultimanet's avatar
Ultimanet committed
457
        self.set_local_data(data=temp_data)
458
        self.hermitian = hermitian_Q
Ultimanet's avatar
Ultimanet committed
459
        return self
460
    """ 
Ultimanet's avatar
Ultimanet committed
461
    
ultimanet's avatar
ultimanet committed
462
    def __add__(self, other):
Ultima's avatar
Ultima committed
463
        return self._builtin_helper(self.get_local_data().__add__, other)
ultimanet's avatar
ultimanet committed
464
465

    def __radd__(self, other):
Ultima's avatar
Ultima committed
466
        return self._builtin_helper(self.get_local_data().__radd__, other)
Ultimanet's avatar
Ultimanet committed
467
468

    def __iadd__(self, other):
Ultima's avatar
Ultima committed
469
        return self._builtin_helper(self.get_local_data().__iadd__, 
470
471
                                               other,
                                               inplace = True)
Ultimanet's avatar
Ultimanet committed
472

ultimanet's avatar
ultimanet committed
473
    def __sub__(self, other):
Ultima's avatar
Ultima committed
474
        return self._builtin_helper(self.get_local_data().__sub__, other)
ultimanet's avatar
ultimanet committed
475
476
    
    def __rsub__(self, other):
Ultima's avatar
Ultima committed
477
        return self._builtin_helper(self.get_local_data().__rsub__, other)
ultimanet's avatar
ultimanet committed
478
479
    
    def __isub__(self, other):
Ultima's avatar
Ultima committed
480
        return self._builtin_helper(self.get_local_data().__isub__, 
481
482
                                               other,
                                               inplace = True)
ultimanet's avatar
ultimanet committed
483
484
        
    def __div__(self, other):
Ultima's avatar
Ultima committed
485
        return self._builtin_helper(self.get_local_data().__div__, other)
ultimanet's avatar
ultimanet committed
486
    
487
488
489
    def __truediv__(self, other):
        return self.__div__(other)
        
ultimanet's avatar
ultimanet committed
490
    def __rdiv__(self, other):
Ultima's avatar
Ultima committed
491
        return self._builtin_helper(self.get_local_data().__rdiv__, other)
492
493
494
    
    def __rtruediv__(self, other):
        return self.__rdiv__(other)
ultimanet's avatar
ultimanet committed
495

Ultimanet's avatar
Ultimanet committed
496
    def __idiv__(self, other):
Ultima's avatar
Ultima committed
497
        return self._builtin_helper(self.get_local_data().__idiv__, 
498
499
                                               other,
                                               inplace = True)
500
    def __itruediv__(self, other):
501
502
        return self.__idiv__(other)
                                               
ultimanet's avatar
ultimanet committed
503
    def __floordiv__(self, other):
Ultima's avatar
Ultima committed
504
        return self._builtin_helper(self.get_local_data().__floordiv__, 
Ultimanet's avatar
Ultimanet committed
505
                                       other)    
ultimanet's avatar
ultimanet committed
506
    def __rfloordiv__(self, other):
Ultima's avatar
Ultima committed
507
        return self._builtin_helper(self.get_local_data().__rfloordiv__, 
Ultimanet's avatar
Ultimanet committed
508
509
                                       other)
    def __ifloordiv__(self, other):
Ultima's avatar
Ultima committed
510
        return self._builtin_helper(
511
512
                    self.get_local_data().__ifloordiv__, other,
                                               inplace = True)
ultimanet's avatar
ultimanet committed
513
514
    
    def __mul__(self, other):
Ultima's avatar
Ultima committed
515
        return self._builtin_helper(self.get_local_data().__mul__, other)
ultimanet's avatar
ultimanet committed
516
517
    
    def __rmul__(self, other):
Ultima's avatar
Ultima committed
518
        return self._builtin_helper(self.get_local_data().__rmul__, other)
ultimanet's avatar
ultimanet committed
519
520

    def __imul__(self, other):
Ultima's avatar
Ultima committed
521
        return self._builtin_helper(self.get_local_data().__imul__, 
522
523
                                               other,
                                               inplace = True)
Ultimanet's avatar
Ultimanet committed
524

ultimanet's avatar
ultimanet committed
525
    def __pow__(self, other):
Ultima's avatar
Ultima committed
526
        return self._builtin_helper(self.get_local_data().__pow__, other)
ultimanet's avatar
ultimanet committed
527
528
 
    def __rpow__(self, other):
Ultima's avatar
Ultima committed
529
        return self._builtin_helper(self.get_local_data().__rpow__, other)
ultimanet's avatar
ultimanet committed
530
531

    def __ipow__(self, other):
Ultima's avatar
Ultima committed
532
        return self._builtin_helper(self.get_local_data().__ipow__, 
533
534
                                               other,
                                               inplace = True)
Ultima's avatar
Ultima committed
535
536
537
538
539
540
541
542
    def __mod__(self, other):
        return self._builtin_helper(self.get_local_data().__mod__, other)
    def __rmod__(self, other):
        return self._builtin_helper(self.get_local_data().__rmod__, other)                                               
    def __imod__(self, other):
        return self._builtin_helper(self.get_local_data().__imod__, 
                                               other,
                                               inplace = True)                                               
543
544
    def __len__(self):
        return self.shape[0]
545
    
546
    def get_dim(self):
547
548
        return np.prod(self.shape)
        
549
    def vdot(self, other):
550
        other = self.distributor.extract_local_data(other)
551
552
553
554
555
        local_vdot = np.vdot(self.get_local_data(), other)
        local_vdot_list = self.distributor._allgather(local_vdot)
        global_vdot = np.sum(local_vdot_list)
        return global_vdot
            
Ultimanet's avatar
Ultimanet committed
556

557
    
ultimanet's avatar
ultimanet committed
558
    def __getitem__(self, key):
Ultima's avatar
Ultima committed
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
        return self.get_data(key)
        
#        ## Case 1: key is a boolean array.
#        ## -> take the local data portion from key, use this for data 
#        ## extraction, and then merge the result in a flat numpy array
#        if isinstance(key, np.ndarray):
#            found = 'ndarray'
#            found_boolean = (key.dtype.type == np.bool_)
#        elif isinstance(key, distributed_data_object):
#            found = 'd2o'
#            found_boolean = (key.dtype == np.bool_)
#        else:
#            found = 'other'
#        ## TODO: transfer this into distributor:
#        if (found == 'ndarray' or found == 'd2o') and found_boolean == True:
#            ## extract the data of local relevance
#            local_bool_array = self.distributor.extract_local_data(key)
#            local_results = self.get_local_data(copy=False)[local_bool_array]
#            global_results = self.distributor._allgather(local_results)
#            global_results = np.concatenate(global_results)
#            return global_results            
#            
#        else:
#            return self.get_data(key)
ultimanet's avatar
ultimanet committed
583
584
585
586
    
    def __setitem__(self, key, data):
        self.set_data(data, key)
        
587
    def _contraction_helper(self, function, **kwargs):
Ultima's avatar
Ultima committed
588
589
590
591
592
593
594
        if np.prod(self.data.shape) == 0:
            local = 0
            include = False
        else:
            local = function(self.data, **kwargs)
            include = True

595
        local_list = self.distributor._allgather(local)
Ultima's avatar
Ultima committed
596
597
598
599
600
601
602
603
604
        local_list = np.array(local_list, dtype = np.dtype(local_list[0]))
        include_list = np.array(self.distributor._allgather(include))
        work_list = local_list[include_list]
        if work_list.shape[0] == 0:
            raise ValueError("ERROR: Zero-size array to reduction operation "+
                             "which has no identity")
        else:                             
            result = function(work_list, axis=0)
            return result
605
606
        
    def amin(self, **kwargs):
607
        return self._contraction_helper(np.amin, **kwargs)
608
609

    def nanmin(self, **kwargs):
610
        return self._contraction_helper(np.nanmin, **kwargs)
611
612
        
    def amax(self, **kwargs):
613
        return self._contraction_helper(np.amax, **kwargs)
614
615
    
    def nanmax(self, **kwargs):
616
        return self._contraction_helper(np.nanmax, **kwargs)
Ultimanet's avatar
Ultimanet committed
617
    
618
619
620
621
622
623
    def sum(self, **kwargs):
        return self._contraction_helper(np.sum, **kwargs)

    def prod(self, **kwargs):
        return self._contraction_helper(np.prod, **kwargs)        
        
624
625
    def mean(self, power=1):
        """
            Returns the global mean of data**power.

            Every node computes the mean of its local data; these partial
            means are combined into the global mean, weighted by the number
            of local elements. Nodes with empty local data do not contribute.

            Parameters
            ----------
            power : number *optional*
                The data is raised to this power before averaging
                (default: 1).

            Returns
            -------
            The global mean.

            Raises
            ------
            ValueError
                If the local data is empty on every node.
        """
        ## compute the local mean and the weight for the mean-mean.
        if np.prod(self.data.shape) == 0:
            local_mean = 0
            include = False
        else:
            local_mean = np.mean(self.data**power)
            include = True
        local_weight = np.prod(self.data.shape)

        ## collect everything (collective calls, same order on all nodes)
        local_mean_list = self.distributor._allgather(local_mean)
        local_weight_list = self.distributor._allgather(local_weight)
        include_list = self.distributor._allgather(include)

        ## Keep only the non-empty nodes *before* choosing the dtype: the
        ## placeholder `0` of an empty node is not valid input for np.dtype.
        work = [(m, w) for (m, w, used) in
                zip(local_mean_list, local_weight_list, include_list) if used]
        if len(work) == 0:
            raise ValueError("ERROR:  Mean of empty slice.")

        work_mean_list = np.array([m for (m, w) in work],
                                  dtype=np.asarray(work[0][0]).dtype)
        work_weight_list = np.array([w for (m, w) in work])
        ## weighted mean-mean
        global_weight = np.sum(work_weight_list)
        numerator = np.sum(work_mean_list * work_weight_list)
        return numerator/global_weight

    def var(self):
        """
            Returns the global variance, computed as mean(x**2) - mean(x)**2.
        """
        second_moment = self.mean(power=2)
        first_moment = self.mean()
        return second_moment - first_moment**2

    def std(self):
        """Returns the global standard deviation, i.e. the root of var()."""
        variance = self.var()
        return np.sqrt(variance)

    def argmin(self):
        """
            Returns the flat global index of the minimum of the data.

            Every node determines its local minimum and the corresponding
            global flat index; the candidates are gathered and sorted by
            (value, index), so ties go to the smallest global index.
            Nodes with empty local data contribute NaN, which sorts last.

            Returns
            -------
            index : int
        """
        if np.prod(self.data.shape) == 0:
            local_argmin_value = np.nan
            globalized_local_argmin = np.nan
        else:
            local_argmin = np.argmin(self.data)
            local_argmin_value = self.data[np.unravel_index(local_argmin,
                                                            self.data.shape)]
            globalized_local_argmin = self.distributor.globalize_flat_index(
                                                                local_argmin)
        local_argmin_list = self.distributor._allgather((local_argmin_value,
                                                    globalized_local_argmin))
        local_argmin_list = np.array(local_argmin_list, dtype=[
                                        ('value', np.dtype('complex128')),
                                        ('index', np.dtype('float'))])
        local_argmin_list = np.sort(local_argmin_list,
                                    order=['value', 'index'])
        ## np.int was removed in NumPy 1.24; the builtin int is equivalent.
        return int(local_argmin_list[0][1])

    def argmax(self):
        """
            Returns the flat global index of the maximum of the data.

            Every node determines its local maximum and the corresponding
            global flat index. The values are negated so that an ascending
            sort by (value, index) puts the maximum first; ties go to the
            smallest global index. Nodes with empty local data contribute
            NaN, which sorts last.

            Returns
            -------
            index : int
        """
        if np.prod(self.data.shape) == 0:
            local_argmax_value = np.nan
            globalized_local_argmax = np.nan
        else:
            local_argmax = np.argmax(self.data)
            ## Cast to complex128 *before* negating: negation wraps around
            ## for unsigned integers and is unsupported for booleans.
            local_argmax_value = -np.complex128(
                self.data[np.unravel_index(local_argmax, self.data.shape)])
            globalized_local_argmax = self.distributor.globalize_flat_index(
                                                                local_argmax)
        local_argmax_list = self.distributor._allgather((local_argmax_value,
                                                    globalized_local_argmax))
        local_argmax_list = np.array(local_argmax_list, dtype=[
                                        ('value', np.dtype('complex128')),
                                        ('index', np.dtype('float'))])
        local_argmax_list = np.sort(local_argmax_list,
                                    order=['value', 'index'])
        ## np.int was removed in NumPy 1.24; the builtin int is equivalent.
        return int(local_argmax_list[0][1])

    def argmin_nonflat(self):
        """Returns the global argmin as a multi-index over self.shape."""
        flat_index = self.argmin()
        return np.unravel_index(flat_index, self.shape)

    def argmax_nonflat(self):
        """Returns the global argmax as a multi-index over self.shape."""
        flat_index = self.argmax()
        return np.unravel_index(flat_index, self.shape)

    def conjugate(self):
        """
            Returns a new object (created by copy_empty) whose local data is
            np.conj of this object's local data.
        """
        result = self.copy_empty()
        result.set_local_data(np.conj(self.get_local_data()))
        return result

    def conj(self):
        """Alias for `conjugate`."""
        conjugated = self.conjugate()
        return conjugated

    def median(self):
        """
            Returns the median of the data.

            The complete data is consolidated with get_full_data before
            np.median is applied, hence the printed cost warning.
        """
        about.warnings.cprint(
            "WARNING: The current implementation of median is very expensive!")
        full_data = self.get_full_data()
        return np.median(full_data)

    def iscomplex(self):
        """
            Returns a boolean object (created by copy_empty) holding
            np.iscomplex of the local data.
        """
        mask = self.copy_empty(dtype=np.dtype('bool'))
        mask.set_local_data(np.iscomplex(self.data))
        return mask

    def isreal(self):
        """
            Returns a boolean object (created by copy_empty) holding
            np.isreal of the local data.
        """
        mask = self.copy_empty(dtype=np.dtype('bool'))
        mask.set_local_data(np.isreal(self.data))
        return mask

    def all(self):
        """Returns whether every element of the data is truthy, on all nodes."""
        gathered = self.distributor._allgather(np.all(self.get_local_data()))
        return np.all(gathered)

    def any(self):
        """Returns whether at least one element on any node is truthy."""
        gathered = self.distributor._allgather(np.any(self.get_local_data()))
        return np.any(gathered)

    def unique(self):
        """Returns the sorted unique values of the data of all nodes."""
        per_node = self.distributor._allgather(np.unique(self.get_local_data()))
        return np.unique(np.concatenate(per_node))

    def bincount(self, weights = None, minlength = None):
        """
            Counts the occurrences of each value in the (integer) data,
            analogous to np.bincount.

            Parameters
            ----------
            weights : array-like *optional*
                Weights for the counted values; the node-local part is
                obtained via distributor.extract_local_data.
            minlength : int *optional*
                Minimal length of the result. The result always has at least
                self.amax()+1 entries, so the per-node count arrays all have
                the same length and can be summed.

            Returns
            -------
            counts : numpy.ndarray
                For the 'not' distribution strategy the local counts,
                otherwise the sum of the counts of all nodes.

            Raises
            ------
            TypeError
                If the data does not have an integer dtype.
        """
        if not np.issubdtype(self.dtype, np.integer):
            raise TypeError(about._errors.cstring(
                "ERROR: Distributed-data-object must be of integer datatype!"))

        ## max(x, None) only works on Python 2; handle None explicitly.
        global_minlength = self.amax() + 1
        if minlength is not None:
            global_minlength = max(global_minlength, minlength)

        if weights is not None:
            local_weights = self.distributor.extract_local_data(weights).\
                                                                    flatten()
        else:
            local_weights = None

        local_counts = np.bincount(self.get_local_data().flatten(),
                                   weights = local_weights,
                                   minlength = global_minlength)
        if self.distribution_strategy == 'not':
            return local_counts
        else:
            list_of_counts = self.distributor._allgather(local_counts)
            counts = np.sum(list_of_counts, axis = 0)
            return counts

    def where(self):
        """Returns distributor.where applied to the local data."""
        local_data = self.data
        return self.distributor.where(local_data)

    def set_local_data(self, data, hermitian=False, copy=True):
        """
            Stores data directly in the local data attribute. No distribution
            is done. The shape of the data must fit the local data
            attribute's shape.

            Parameters
            ----------
            data : tuple, list, numpy.ndarray
                The data which should be stored in the local data attribute.
            hermitian : boolean *optional*
                Stored as the object's `hermitian` flag.
            copy : boolean *optional*
                If True (default), the values are written into the existing
                local array (self.data[:] = data). If False, self.data is
                replaced by a C-ordered array of self.dtype built from `data`
                (without copying when `data` already fits), reshaped to
                self.local_shape.

            Returns
            -------
            None
        """
        self.hermitian = hermitian
        if copy:
            self.data[:] = data
        else:
            ## np.asarray copies only when necessary. np.array(copy=False)
            ## used to mean the same, but raises under NumPy >= 2.0 whenever
            ## a copy cannot be avoided (e.g. for lists).
            self.data = np.asarray(data, dtype=self.dtype,
                                   order='C').reshape(self.local_shape)

    def set_data(self, data, to_key, from_key=None, local_keys=False,
                 hermitian=False, copy=True, **kwargs):
        """
            Stores the supplied data in the region which is specified by
            `to_key`. The work is done by distributor.disperse_data, which
            distributes the data according to the distribution strategy.

            Parameters
            ----------
            data : tuple, list, numpy.ndarray
                The data which should be distributed (passed on as
                `data_update`).
            to_key : int, slice, tuple of int or slice
                The region of this object where the data will be stored.
            from_key : *optional*
                Passed on to disperse_data; judging by its name it selects
                the part of `data` to be used -- TODO confirm.
            local_keys : boolean *optional*
                Passed on to disperse_data. According to the original
                documentation, nodes may then supply different keys, which
                are processed one-by-one.
            hermitian : boolean *optional*
                Stored as the object's `hermitian` flag.
            copy : boolean *optional*
                Passed on to disperse_data.
            **kwargs
                Passed on to disperse_data.

            Returns
            -------
            None
        """
        self.hermitian = hermitian
        target_array = self.data
        self.distributor.disperse_data(data=target_array,
                                       to_key=to_key,
                                       data_update=data,
                                       from_key=from_key,
                                       local_keys=local_keys,
                                       copy=copy,
                                       **kwargs)

    def set_full_data(self, data, hermitian=False, copy = True, **kwargs):
        """
            Distributes the supplied data to the nodes. The shape of data must
            match the shape of the distributed_data_object.

            Parameters
            ----------
            data : tuple, list, numpy.ndarray
                The data which should be distributed.
            hermitian : boolean *optional*
                Stored as the object's `hermitian` flag.
            copy : boolean *optional*
                Passed on to distributor.distribute_data.
            **kwargs
                Passed on to distributor.distribute_data.

            Notes
            -----
            The result of distributor.distribute_data replaces the local data
            attribute. The original documentation describes this as
            equivalent to set_data(foo, slice(None)) but faster.

            Returns
            -------
            None
        """
        self.hermitian = hermitian
        distributed = self.distributor.distribute_data(data=data, copy=copy,
                                                       **kwargs)
        self.data = distributed

    def get_local_data(self, key=(slice(None),), copy=True):
        """
            Loads data directly from the local data attribute. No
            consolidation is done.

            Parameters
            ----------
            key : int, slice, tuple of int or slice
                The key which will be used to access the data. Only used if
                `copy` is true.
            copy : boolean *optional*
                If true (default), self.data[key] is returned. Note that
                numpy returns a view (not an independent copy) for basic
                slicing keys such as the default. If false, `key` is ignored
                and the local array itself is returned.

            Returns
            -------
            self.data[key] or self.data : numpy.ndarray
        """
        ## Previously two independent `if`s compared against True/False, so
        ## any other value (e.g. None) fell through and returned None.
        if copy:
            return self.data[key]
        return self.data

    def get_data(self, key, local_keys=False, **kwargs):
        """
            Loads data from the region which is specified by `key`. The data
            is consolidated according to the distribution strategy.

            Parameters
            ----------
            key : int, slice, tuple of int or slice
                The region to load. None, slice(None) and tuples consisting
                only of slice(None) select everything; for those keys a copy
                of this object (self.copy()) is returned instead.
            local_keys : boolean *optional*
                Passed on to distributor.collect_data. According to the
                original documentation, nodes may then pass different keys
                and receive individual data.
            **kwargs
                Passed on to distributor.collect_data.

            Returns
            -------
            self.copy() for full-selection keys, otherwise the result of
            distributor.collect_data.
        """
        if key is None:
            return self.copy()

        if isinstance(key, slice):
            if key == slice(None):
                return self.copy()
        elif isinstance(key, tuple):
            try:
                everything = all(part == slice(None) for part in key)
                if everything:
                    return self.copy()
            except(ValueError):
                ## e.g. array entries whose comparison result is ambiguous
                pass

        return self.distributor.collect_data(self.data,
                                             key,
                                             local_keys=local_keys,
                                             **kwargs)

    def get_full_data(self, target_rank='all'):
        """
            Fully consolidates the distributed data.

            Parameters
            ----------
            target_rank : 'all' (default), int *optional*
                If only one node should receive the full data, it can be
                specified here.

            Notes
            -----
            Unlike get_data(slice(None)), which returns a copy of this
            object, this returns the consolidated data itself.

            Returns
            -------
            The result of distributor.consolidate_data for the local data
            (presumably a numpy.ndarray -- TODO confirm).
        """
        consolidate = self.distributor.consolidate_data
        return consolidate(self.data, target_rank=target_rank)

    def inject(self, to_key=(slice(None),), data=None,
               from_key=(slice(None),)):
        """
            Delegates to distributor.inject(self.data, to_key, data,
            from_key). Judging by the parameter names this writes the region
            `from_key` of `data` into the region `to_key` of this object --
            TODO confirm against the distributor implementation.

            If `data` is None nothing is done.

            Returns
            -------
            self
                Returned in both cases (previously only when data was None,
                otherwise None).
        """
        if data is not None:
            self.distributor.inject(self.data, to_key, data, from_key)
        return self

    def flatten(self, inplace = False):
        """
            Returns a one-dimensional object holding the flattened data.

            The distributor flattens the local data. The result keeps this
            object's distribution strategy if the flattened local shapes fit
            it on *every* node; otherwise a 'freeform' object is created.

            Parameters
            ----------
            inplace : boolean *optional*
                Passed on to distributor.flatten.
                NOTE(review): rebinding `self` inside a method has no effect
                for the caller, so the result must always be taken from the
                return value, also for inplace=True.

            Returns
            -------
            distributed_data_object
        """
        flat_data = self.distributor.flatten(self.data, inplace = inplace)

        flat_global_shape = (np.prod(self.shape),)
        flat_local_shape = np.shape(flat_data)

        ## Try to keep the distribution strategy: create an empty copy of
        ## self with the flat shape.
        temp_d2o = self.copy_empty(global_shape = flat_global_shape,
                                   local_shape = flat_local_shape)

        ## The choice must be made collectively: creating the freeform
        ## object communicates, so all nodes have to take the same branch
        ## even if the local shapes match on some nodes only.
        local_match = (temp_d2o.local_shape == flat_local_shape)
        if np.all(self.distributor._allgather(local_match)):
            work_d2o = temp_d2o
        else:
            work_d2o = self.copy_empty(local_shape = flat_local_shape,
                                       distribution_strategy = 'freeform')

        ## Feed the work_d2o with the flat data
        work_d2o.set_local_data(data = flat_data, copy = False)
        return work_d2o

    def save(self, alias, path=None, overwriteQ=True):
        """
            Saves a distributed_data_object to disk by delegating to
            distributor.save_data (the original documentation says h5py is
            used for this).

            Parameters
            ----------
            alias : string
                The name for the dataset which is saved within the hdf5 file.

            path : string *optional*
                The path to the hdf5 file. According to the original
                documentation the alias is taken as filename if no path is
                given.

            overwriteQ : Boolean *optional*
                Specifies whether a dataset may be overwritten if it is
                already present in the given hdf5 file or not.
        """
        distributor = self.distributor
        distributor.save_data(self.data, alias, path, overwriteQ)

    def load(self, alias, path=None):
        """
            Loads a distributed_data_object from disk by delegating to
            distributor.load_data (the original documentation says h5py is
            used for this); the result replaces the local data.

            Parameters
            ----------
            alias : string
                The name of the dataset which is loaded from the hdf5 file.

            path : string *optional*
                The path to the hdf5 file. According to the original
                documentation the alias is taken as filename if no path is
                given.
        """
        loaded = self.distributor.load_data(alias, path)
        self.data = loaded

class _distributor_factory(object):
    def __init__(self):
        self.distributor_store = {}
    
Ultima's avatar
Ultima committed
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
#    def parse_kwargs(self, strategy = None, kwargs = {}):
#        return_dict = {}
#        if strategy == 'not':
#            pass
#        ## These strategies use MPI and therefore accept a MPI.comm object
#        if strategy == 'fftw' or strategy == 'equal' or strategy == 'freeform':
#            if kwargs.has_key('comm'):
#                return_dict['comm'] = kwargs['comm']
#
#        return return_dict
    
1059
    def parse_kwargs(self, distribution_strategy, comm,
Ultima's avatar
Ultima committed
1060
1061
1062
                   global_data = None, global_shape = None,
                   local_data = None, local_shape = None,
                   alias = None, path = None,
1063
                   dtype = None, **kwargs):
Ultima's avatar
Ultima committed
1064

1065
        return_dict = {}
Ultima's avatar
Ultima committed
1066

1067
1068
        ## Check that all nodes got the same distribution_strategy
        strat_list = comm.allgather(distribution_strategy)
Ultima's avatar
Ultima committed
1069
1070
1071
1072
        if all(x == strat_list[0] for x in strat_list) == False:
            raise ValueError(about._errors.cstring(
                "ERROR: The distribution-strategy must be the same on "+
                "all nodes!"))
1073

Ultima's avatar
Ultima committed
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
        ## Check for an hdf5 file and open it if given
        if FOUND['h5py'] == True and alias is not None:
            ## set file path            
            file_path = path if (path is not None) else alias 
            ## open hdf5 file
            if FOUND['h5py_parallel'] == True and FOUND['MPI'] == True:
                f = h5py.File(file_path, 'r', driver='mpio', comm=comm)
            else:
                f = h5py.File(file_path, 'r')   
            ## open alias in file
            dset = f[alias] 
        else:
            dset = None


        ## Parse the MPI communicator        
1090
1091
1092
1093
1094
        if comm is None:
            raise ValueError(about._errors.cstring(
        "ERROR: The distributor needs a MPI communicator object comm!"))
        else:
            return_dict['comm'] = comm
Ultima's avatar
Ultima committed
1095
1096
1097
1098
        
        ## Parse the datatype
        if distribution_strategy in ['not', 'equal', 'fftw'] and \
            (dset is not None):
1099
            dtype = dset.dtype
Ultima's avatar
Ultima committed
1100
        
1101
        elif distribution_strategy in ['not', 'equal', 'fftw']: 
Ultima's avatar
Ultima committed
1102
            if dtype is None:
1103
                if global_data is None:
Ultima's avatar
Ultima committed
1104
1105
                    dtype = np.dtype('float64')
                    about.infos.cprint('INFO: dtype set was set to default.')
1106
                else:
Ultima's avatar
Ultima committed
1107
                    try:
1108
                        dtype = global_data.dtype
Ultima's avatar
Ultima committed
1109
                    except(AttributeError):
1110
1111
1112
1113
1114
1115
                        dtype = np.array(global_data).dtype
            else:
                dtype = np.dtype(dtype)
                
        elif distribution_strategy in ['freeform']:
            if dtype is None:
Ultima's avatar
Ultima committed
1116
1117
1118
                if isinstance(global_data, distributed_data_object):
                    dtype = global_data.dtype
                elif local_data is not None:
Ultima's avatar
Ultima committed
1119
                    try:
1120
                        dtype = local_data.dtype
Ultima's avatar
Ultima committed
1121
                    except(AttributeError):
1122
                        dtype = np.array(local_data).dtype
Ultima's avatar
Ultima committed
1123
1124
1125
1126
                else:
                    dtype = np.dtype('float64')
                    about.infos.cprint('INFO: dtype set was set to default.')
                
Ultima's avatar
Ultima committed
1127
            else:
1128
1129
                dtype = np.dtype(dtype)
        dtype_list = comm.allgather(dtype)
Ultima's avatar
Ultima committed
1130
1131
1132
        if all(x == dtype_list[0] for x in dtype_list) == False:
            raise ValueError(about._errors.cstring(
            "ERROR: The given dtype must be the same on all nodes!"))
Ultima's avatar
Ultima committed
1133
        return_dict['dtype'] = dtype
1134
        
Ultima's avatar
Ultima committed
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
        ## Parse the shape
        ## Case 1: global-type slicer
        if distribution_strategy in ['not', 'equal', 'fftw']:       
            if dset is not None:
                global_shape = dset.shape
            elif global_data is not None and np.isscalar(global_data) == False:
                global_shape = global_data.shape
            elif global_shape is not None:
                global_shape = tuple(global_shape)
            else:
                raise ValueError(about._errors.cstring(
                    "ERROR: Neither non-0-dimensional global_data nor " +
                    "global_shape nor hdf5 file supplied!"))      
            if global_shape == ():
                raise ValueError(about._errors.cstring(
Ultima's avatar
Ultima committed
1150
                    "ERROR: global_shape == () is not a valid shape!"))
Ultima's avatar
Ultima committed
1151
            
1152
            global_shape_list = comm.allgather(global_shape)
Ultima's avatar
Ultima committed
1153
1154
1155
            if not all(x == global_shape_list[0] for x in global_shape_list):
                raise ValueError(about._errors.cstring(
                    "ERROR: The global_shape must be the same on all nodes!"))
Ultima's avatar
Ultima committed
1156
1157
1158
            return_dict['global_shape'] = global_shape

        ## Case 2: local-type slicer
Ultima's avatar
Ultima committed
1159
1160
1161
1162
        elif distribution_strategy in ['freeform']:
            if isinstance(global_data, distributed_data_object):
                local_shape = global_data.local_shape
            elif local_data is not None and np.isscalar(local_data) == False:
Ultima's avatar
Ultima committed
1163
1164
1165
1166
1167
1168
                local_shape = local_data.shape
            elif local_shape is not None:
                local_shape = tuple(local_shape)
            else:
                raise ValueError(about._errors.cstring(
                    "ERROR: Neither non-0-dimensional local_data nor " +
Ultima's avatar
Ultima committed
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
                    "local_shape nor global d2o supplied!"))      
            if local_shape == ():
                raise ValueError(about._errors.cstring(
                    "ERROR: local_shape == () is not a valid shape!"))  
            
            local_shape_list = comm.allgather(local_shape[1:])
            cleared_set = set(local_shape_list)
            cleared_set.discard(())
            if len(cleared_set) > 1:
            #if not any(x == () for x in map(np.shape, local_shape_list)):
            #if not all(x == local_shape_list[0] for x in local_shape_list):
                raise ValueError(about._errors.cstring(
                    "ERROR: All but the first entry of local_shape must be "+
                    "the same on all nodes!"))
Ultima's avatar
Ultima committed
1183
1184
            return_dict['local_shape'] = local_shape

1185
1186
1187
        ## Add the name of the distributor if needed
        if distribution_strategy in ['equal', 'fftw', 'freeform']:
            return_dict['name'] = distribution_strategy
Ultima's avatar
Ultima committed
1188
1189
1190
1191
1192
            
        ## close the file-handle
        if dset is not None:
            f.close()

1193
        return return_dict
Ultima's avatar
Ultima committed
1194
1195
1196
            
            
    def hash_arguments(self, distribution_strategy, **kwargs):
1197
        kwargs = kwargs.copy()
1198
1199
1200
        
        comm = kwargs['comm']
        kwargs['comm'] = id(comm)
Ultima's avatar
Ultima committed
1201
1202
1203
1204
        
        if kwargs.has_key('global_shape'):
            kwargs['global_shape'] = kwargs['global_shape']
        if kwargs.has_key('local_shape'):
1205
1206
1207
            local_shape = kwargs['local_shape']
            local_shape_list = comm.allgather(local_shape)
            kwargs['local_shape'] = tuple(local_shape_list)
Ultima's avatar
Ultima committed
1208
1209
            
        kwargs['dtype'] = self.dictionize_np(kwargs['dtype'])
1210
        kwargs['distribution_strategy'] = distribution_strategy
1211
        
1212
        return frozenset(kwargs.items())
ultimanet's avatar
ultimanet committed
1213

1214
    def dictionize_np(self, x):
1215
        dic = x.type.__dict__.items()
1216
1217
1218
1219
1220
1221
        if x is np.float:
            dic[24] = 0 
            dic[29] = 0
            dic[37] = 0
        return frozenset(dic)            
            
1222
    def get_distributor(self, distribution_strategy, comm, **kwargs):
1223
        ## check if the distribution strategy is known
1224
        
Ultima's avatar
Ultima committed
1225
1226
        known_distribution_strategies = ['not', 'equal', 'freeform']
        if FOUND['pyfftw'] == True:
1227
            known_distribution_strategies += ['fftw',]
Ultima's avatar
Ultima committed
1228
        if not distribution_strategy in known_distribution_strategies:
1229
1230
1231
1232
            raise TypeError(about._errors.cstring(
                "ERROR: Unknown distribution strategy supplied."))
                
        ## parse the kwargs
Ultima's avatar
Ultima committed
1233
1234
        parsed_kwargs = self.parse_kwargs(
                                distribution_strategy = distribution_strategy, 
1235
                                comm = comm,
Ultima's avatar
Ultima committed
1236
1237
1238
1239
                                **kwargs)
                                
        hashed_kwargs = self.hash_arguments(distribution_strategy,
                                            **parsed_kwargs)
1240
        ## check if the distributors has already been produced in the past
Ultima's avatar
Ultima committed
1241
1242
        if self.distributor_store.has_key(hashed_kwargs):
            return self.distributor_store[hashed_kwargs]
1243
        else:
1244
1245
            ## produce new distributor
            if distribution_strategy == 'not':