nifty_mpi_data.py 76.8 KB
Newer Older
ultimanet's avatar
ultimanet committed
1
# -*- coding: utf-8 -*-
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
## NIFTY (Numerical Information Field Theory) has been developed at the
## Max-Planck-Institute for Astrophysics.
##
## Copyright (C) 2015 Max-Planck-Society
##
## Author: Theo Steininger
## Project homepage: <http://www.mpa-garching.mpg.de/ift/nifty/>
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
## See the GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <http://www.gnu.org/licenses/>.


ultimanet's avatar
ultimanet committed
24
25


26

27
##initialize the 'found-packages'-dictionary 
28
found = {}
ultimanet's avatar
ultimanet committed
29
import numpy as np
Ultimanet's avatar
Ultimanet committed
30
from nifty_about import about
ultimanet's avatar
ultimanet committed
31
32

try:
33
    from mpi4py import MPI
34
    found['MPI'] = True
ultimanet's avatar
ultimanet committed
35
except(ImportError): 
36
    import mpi_dummy as MPI
37
    found['MPI'] = False
ultimanet's avatar
ultimanet committed
38
39
40
41
42
43
44
45

try:
    import pyfftw
    found['pyfftw'] = True
except(ImportError):       
    found['pyfftw'] = False

try:
46
    import h5py
ultimanet's avatar
ultimanet committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
    found['h5py'] = True
    found['h5py_parallel'] = h5py.get_config().mpi
except(ImportError):
    found['h5py'] = False
    found['h5py_parallel'] = False


class distributed_data_object(object):
    """

        NIFTY class for distributed data

        Parameters
        ----------
        global_data : {tuple, list, numpy.ndarray} *at least 1-dimensional*
            Initial data which will be casted to a numpy.ndarray and then 
            stored according to the distribution strategy. The global_data's
            shape overwrites global_shape.
        global_shape : tuple of ints, *optional*
            If no global_data is supplied, global_shape can be used to
            initialize an empty distributed_data_object
        dtype : type, *optional*
            If an explicit dtype is supplied, the given global_data will be 
            casted to it.            
        distribution_strategy : {'fftw' (default), 'not'}, *optional*
            Specifies the way, how global_data will be distributed to the 
            individual nodes. 
            'fftw' follows the distribution strategy of pyfftw.
            'not' does not distribute the data at all. 
            

        Attributes
        ----------
        data : numpy.ndarray
            The numpy.ndarray in which the individual node's data is stored.
        dtype : type
            Data type of the data object.
        distribution_strategy : string
            Name of the used distribution_strategy
        distributor : distributor
            The distributor object which takes care of all distribution and 
            consolidation of the data. 
        shape : tuple of int
            The global shape of the data
            
        Raises
        ------
        TypeError : 
            If the supplied distribution strategy is not known. 
        
    """
98
99
100
    def __init__(self, global_data = None, global_shape=None, dtype=None, 
                 distribution_strategy='fftw', hermitian=False,
                 alias=None, path=None, comm = MPI.COMM_WORLD, 
                 copy = True, *args, **kwargs):
        """
            Determines shape and dtype, builds the distributor and stores
            the initial data.

            Parameters
            ----------
            global_data : {scalar, tuple, list, numpy.ndarray}, *optional*
                Initial data. A scalar is broadcast to the whole local array.
            global_shape : tuple of ints, *optional*
                Only used if global_data is None or a scalar.
            dtype : type, *optional*
                Overrides the dtype that would be derived from global_data.
            distribution_strategy : string, *optional*
                Passed on to distributor_factory.get_distributor.
            hermitian : bool, *optional*
                Forwarded to set_full_data.
            alias : string, *optional*
                Name of a dataset in an hdf5 file. If h5py is available and
                alias is given, shape, dtype and data are taken from the file
                and the other data arguments are ignored.
            path : string, *optional*
                Path of the hdf5 file; defaults to alias.
            comm : MPI communicator, *optional*
                Handed to h5py when the file is opened with the 'mpio' driver.
            copy : bool, *optional*
                Forwarded to set_full_data.
            *args, **kwargs
                Stored as init_args / init_kwargs; kwargs are additionally 
                forwarded to the distributor factory and to set_full_data.

            Raises
            ------
            ValueError
                If neither dtype nor global_data is given, or if no shape 
                can be determined.
        """
        ## a given hdf5 file overwrites the other parameters
        use_hdf5 = (found['h5py'] == True and alias is not None)
        hdf5_file = None
        try:
            if use_hdf5:
                ## set file path
                file_path = path if (path is not None) else alias
                ## open hdf5 file
                if found['h5py_parallel'] == True and found['MPI'] == True:
                    hdf5_file = h5py.File(file_path, 'r', driver='mpio',
                                          comm=comm)
                else:
                    hdf5_file = h5py.File(file_path, 'r')
                ## shape and dtype are dictated by the dataset
                dset = hdf5_file[alias]
                global_shape = dset.shape
                dtype = dset.dtype.type

            ## if no hdf5 path was given, extract global_shape and dtype from
            ## the remaining arguments
            else:
                ## an explicitly given dtype overwrites the one from 
                ## global_data
                if dtype is None:
                    if global_data is None:
                        raise ValueError(about._errors.cstring(
                            "ERROR: Neither global_data nor dtype supplied!"))
                    try:
                        dtype = global_data.dtype.type
                    except(AttributeError):
                        try:
                            dtype = global_data.dtype
                        except(AttributeError):
                            dtype = np.array(global_data).dtype.type

                ## an explicitly given global_shape argument is only used if
                ## 1. no global_data was supplied, or
                ## 2. global_data is a scalar/list of dimension 0.
                if global_data is None or np.isscalar(global_data):
                    if global_shape is None:
                        raise ValueError(about._errors.cstring(
    "ERROR: Neither non-0-dimensional global_data nor global_shape supplied!"))
                    global_shape = tuple(global_shape)
                else:
                    ## np.shape also handles lists and tuples, which have no
                    ## .shape attribute
                    global_shape = np.shape(global_data)

            self.distributor = distributor_factory.get_distributor(
                                distribution_strategy = distribution_strategy,
                                global_shape = global_shape,
                                dtype = dtype,
                                **kwargs)

            self.distribution_strategy = distribution_strategy
            self.dtype = self.distributor.dtype
            self.shape = self.distributor.global_shape

            self.init_args = args
            self.init_kwargs = kwargs

            ## If a hdf5 path was given, load the data
            if use_hdf5:
                self.load(alias = alias, path = path)

            ## If the input data was a scalar, set the whole array to this 
            ## value. 'is not None' avoids an element-wise array comparison.
            elif global_data is not None and np.isscalar(global_data):
                ## allocate with the final dtype so that e.g. complex scalars
                ## are not forced through a float64 buffer
                temp = np.empty(self.distributor.local_shape, 
                                dtype=self.dtype)
                temp.fill(global_data)
                self.set_local_data(temp)
                self.hermitian = True
            else:
                self.set_full_data(data=global_data, hermitian=hermitian,
                                   copy = copy, **kwargs)
        finally:
            ## close the file handle, also if something above failed
            if hdf5_file is not None:
                hdf5_file.close()
181
            
Ultimanet's avatar
Ultimanet committed
182
183
184
185
186
187
188
189
    def copy(self, dtype=None, distribution_strategy=None, **kwargs):
        """
            Returns a copy of this object, optionally with another dtype
            and/or distribution strategy.

            Parameters
            ----------
            dtype : type, *optional*
                dtype of the copy; handed to copy_empty.
            distribution_strategy : string, *optional*
                Strategy of the copy; handed to copy_empty.
            **kwargs
                Forwarded to copy_empty.

            Returns
            -------
            temp_d2o : distributed_data_object
                New object carrying the same data and hermitian flag.
        """
        temp_d2o = self.copy_empty(dtype=dtype, 
                                   distribution_strategy=distribution_strategy, 
                                   **kwargs)
        same_strategy = (distribution_strategy is None or
                         distribution_strategy == self.distribution_strategy)
        if same_strategy:
            ## identical layout: the local arrays can be copied directly
            temp_d2o.set_local_data(self.get_local_data(), copy=True)
        else:
            ## different layout: let inject redistribute the data
            temp_d2o.inject([slice(None),], self, [slice(None),])
        temp_d2o.hermitian = self.hermitian
        return temp_d2o
    
195
196
197
198
199
200
201
202
203
204
205
206
207
208
    def copy_empty(self, global_shape=None, dtype=None, 
                   distribution_strategy=None, **kwargs):
        """
            Creates a new distributed_data_object without copying the data.

            Arguments that are left at None are taken from this object.

            Parameters
            ----------
            global_shape : tuple of ints, *optional*
            dtype : type, *optional*
            distribution_strategy : string, *optional*
            **kwargs
                Extra constructor keywords. Keys that are also present in 
                this object's init_kwargs are overwritten by the stored 
                values.

            Returns
            -------
            temp_d2o : distributed_data_object
        """
        ## identity checks: '== None' is not reliable for numpy dtype 
        ## instances and is element-wise for array-like shapes
        if global_shape is None:
            global_shape = self.shape
        if dtype is None:
            dtype = self.dtype
        if distribution_strategy is None:
            distribution_strategy = self.distribution_strategy

        kwargs.update(self.init_kwargs)

        temp_d2o = distributed_data_object(global_shape=global_shape,
                                           dtype=dtype,
                                           distribution_strategy=distribution_strategy,
                                           *self.init_args,
                                           **kwargs)
        return temp_d2o
    
213
    def apply_scalar_function(self, function, inplace=False, dtype=None):
214
215
        remember_hermitianQ = self.hermitian
        
Ultimanet's avatar
Ultimanet committed
216
217
        if inplace == True:        
            temp = self
218
219
220
221
            if dtype != None and self.dtype != dtype:
                about.warnings.cprint(\
            "WARNING: Inplace dtype conversion is not possible!")
                
Ultimanet's avatar
Ultimanet committed
222
        else:
223
            temp = self.copy_empty(dtype=dtype)
Ultimanet's avatar
Ultimanet committed
224
225
226
227
228

        try: 
            temp.data[:] = function(self.data)
        except:
            temp.data[:] = np.vectorize(function)(self.data)
229
        
230
231
232
233
        if function in (np.exp, np.log):
            temp.hermitian = remember_hermitianQ
        else:
            temp.hermitian = False
Ultimanet's avatar
Ultimanet committed
234
235
236
237
238
239
        return temp
    
    def apply_generator(self, generator):
        self.set_local_data(generator(self.distributor.local_shape))
        self.hermitian = False
            
ultimanet's avatar
ultimanet committed
240
241
242
243
244
245
    def __str__(self):
        return self.data.__str__()
    
    def __repr__(self):
        return '<distributed_data_object>\n'+self.data.__repr__()
    
Ultimanet's avatar
Ultimanet committed
246
    def __eq__(self, other):
247
        result = self.copy_empty(dtype = np.bool_)
Ultimanet's avatar
Ultimanet committed
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
        ## Case 1: 'other' is a scalar
        ## -> make point-wise comparison
        if np.isscalar(other):
            result.set_local_data(self.get_local_data(copy = False) == other)
            return result        

        ## Case 2: 'other' is a numpy array or a distributed_data_object
        ## -> extract the local data and make point-wise comparison
        elif isinstance(other, np.ndarray) or\
        isinstance(other, distributed_data_object):
            temp_data = self.distributor.extract_local_data(other)
            result.set_local_data(self.get_local_data(copy=False) == temp_data)
            return result
        
        ## Case 3: 'other' is None
        elif other == None:
            return False
        
        ## Case 4: 'other' is something different
267
        ## -> make a numpy casting and make a recursive call
Ultimanet's avatar
Ultimanet committed
268
269
270
271
272
273
274
275
        else:
            temp_other = np.array(other)
            return self.__eq__(temp_other)
            
            
        
    
    def equal(self, other):
Ultimanet's avatar
Ultimanet committed
276
277
278
279
280
281
282
283
284
        if other is None:
            return False
        try:
            assert(self.dtype == other.dtype)
            assert(self.shape == other.shape)
            assert(self.init_args == other.init_args)
            assert(self.init_kwargs == other.init_kwargs)
            assert(self.distribution_strategy == other.distribution_strategy)
            assert(np.all(self.data == other.data))
Ultimanet's avatar
Ultimanet committed
285
        except(AssertionError, AttributeError):
Ultimanet's avatar
Ultimanet committed
286
287
288
289
290
291
292
            return False
        else:
            return True
        

            
    
293
    def __pos__(self):
        """Returns a new object holding a copy of the local data."""
        positive = self.copy_empty()
        local_values = self.get_local_data()
        positive.set_local_data(data = local_values, copy = True)
        return positive
        
ultimanet's avatar
ultimanet committed
298
    def __neg__(self):
        """Returns a new object holding the negated local data."""
        negated = self.copy_empty()
        local_negative = self.get_local_data().__neg__()
        negated.set_local_data(data = local_negative, copy = True)
        return negated
    
304
    def __abs__(self):
        """
            Returns a new object holding the absolute values of the local 
            data. Complex dtypes are translated to a real dtype.
        """
        ## translate complex dtypes; the builtins 'complex' and 'float' are
        ## what the removed numpy aliases np.complex / np.float referred to
        if self.dtype == np.complex64:
            new_dtype = np.float32
        elif self.dtype == np.complex128:
            new_dtype = np.float64
        elif self.dtype == complex or \
                issubclass(self.dtype, np.complexfloating):
            new_dtype = float
        else:
            new_dtype = self.dtype
        temp_d2o = self.copy_empty(dtype = new_dtype)
        temp_d2o.set_local_data(data = self.get_local_data().__abs__(),
                                copy = True)
        return temp_d2o
ultimanet's avatar
ultimanet committed
320
            
321
    def __builtin_helper__(self, operator, other, inplace=False):
Ultimanet's avatar
Ultimanet committed
322
323
324
325
326
        ## Case 1: other is not a scalar
        if not (np.isscalar(other) or np.shape(other) == (1,)):
##            if self.shape != other.shape:            
##                raise AttributeError(about._errors.cstring(
##                    "ERROR: Shapes do not match!")) 
327
            try:            
328
                hermitian_Q = (other.hermitian and self.hermitian)
329
330
            except(AttributeError):
                hermitian_Q = False
Ultimanet's avatar
Ultimanet committed
331
332
333
            ## extract the local data from the 'other' object
            temp_data = self.distributor.extract_local_data(other)
            temp_data = operator(temp_data)
Ultimanet's avatar
Ultimanet committed
334
            
335
336
337
338
        ## Case 2: other is a real scalar -> preserve hermitianity
        elif np.isreal(other) or (self.dtype not in (np.complex, np.complex128,
                                                np.complex256)):
            hermitian_Q = self.hermitian
ultimanet's avatar
ultimanet committed
339
            temp_data = operator(other)
340
341
342
343
        ## Case 3: other is complex
        else:
            hermitian_Q = False
            temp_data = operator(other)        
Ultimanet's avatar
Ultimanet committed
344
        ## write the new data into a new distributed_data_object        
345
346
347
        if inplace == True:
            temp_d2o = self
        else:
348
349
350
351
352
353
            ## use common datatype for self and other
            new_dtype = np.dtype(np.find_common_type((self.dtype,),
                                                     (temp_data.dtype,))).type
            print new_dtype                                                        
            temp_d2o = self.copy_empty(
                            dtype = new_dtype)
ultimanet's avatar
ultimanet committed
354
        temp_d2o.set_local_data(data=temp_data)
355
        temp_d2o.hermitian = hermitian_Q
ultimanet's avatar
ultimanet committed
356
        return temp_d2o
357
    """
Ultimanet's avatar
Ultimanet committed
358
    def __inplace_builtin_helper__(self, operator, other):
359
        ## Case 1: other is not a scalar
Ultimanet's avatar
Ultimanet committed
360
361
362
        if not (np.isscalar(other) or np.shape(other) == (1,)):        
            temp_data = self.distributor.extract_local_data(other)
            temp_data = operator(temp_data)
363
364
365
        ## Case 2: other is a real scalar -> preserve hermitianity
        elif np.isreal(other):
            hermitian_Q = self.hermitian
Ultimanet's avatar
Ultimanet committed
366
            temp_data = operator(other)
367
368
369
        ## Case 3: other is complex
        else:
            temp_data = operator(other)        
Ultimanet's avatar
Ultimanet committed
370
        self.set_local_data(data=temp_data)
371
        self.hermitian = hermitian_Q
Ultimanet's avatar
Ultimanet committed
372
        return self
373
    """ 
Ultimanet's avatar
Ultimanet committed
374
    
ultimanet's avatar
ultimanet committed
375
376
377
378
379
    def __add__(self, other):
        """Implements self + other via __builtin_helper__."""
        local_op = self.get_local_data().__add__
        return self.__builtin_helper__(local_op, other)

    def __radd__(self, other):
        """Implements other + self via __builtin_helper__."""
        local_op = self.get_local_data().__radd__
        return self.__builtin_helper__(local_op, other)
Ultimanet's avatar
Ultimanet committed
380
381

    def __iadd__(self, other):
382
383
384
        return self.__builtin_helper__(self.get_local_data().__iadd__, 
                                               other,
                                               inplace = True)
Ultimanet's avatar
Ultimanet committed
385

ultimanet's avatar
ultimanet committed
386
387
388
389
390
391
392
    def __sub__(self, other):
        """Implements self - other via __builtin_helper__."""
        local_op = self.get_local_data().__sub__
        return self.__builtin_helper__(local_op, other)
    
    def __rsub__(self, other):
        """Implements other - self via __builtin_helper__."""
        local_op = self.get_local_data().__rsub__
        return self.__builtin_helper__(local_op, other)
    
    def __isub__(self, other):
393
394
395
        return self.__builtin_helper__(self.get_local_data().__isub__, 
                                               other,
                                               inplace = True)
ultimanet's avatar
ultimanet committed
396
397
398
399
        
    def __div__(self, other):
        """Implements classic (Python 2) self / other."""
        local_op = self.get_local_data().__div__
        return self.__builtin_helper__(local_op, other)
    
400
401
402
    def __truediv__(self, other):
        """
            Implements true division self / other.

            Uses the array's own __truediv__ instead of delegating to the
            classic __div__, so integer data is not floor-divided and the 
            operator also works where arrays have no __div__ (Python 3).
        """
        local_op = self.get_local_data().__truediv__
        return self.__builtin_helper__(local_op, other)
        
ultimanet's avatar
ultimanet committed
403
404
    def __rdiv__(self, other):
        """Implements classic (Python 2) other / self."""
        local_op = self.get_local_data().__rdiv__
        return self.__builtin_helper__(local_op, other)
405
406
407
    
    def __rtruediv__(self, other):
        """
            Implements true division other / self.

            Uses the array's own __rtruediv__ instead of the classic 
            __rdiv__, which floor-divides integers and does not exist on
            Python 3 arrays.
        """
        local_op = self.get_local_data().__rtruediv__
        return self.__builtin_helper__(local_op, other)
ultimanet's avatar
ultimanet committed
408

Ultimanet's avatar
Ultimanet committed
409
    def __idiv__(self, other):
410
411
412
        return self.__builtin_helper__(self.get_local_data().__idiv__, 
                                               other,
                                               inplace = True)
413
    def __itruediv__(self, other):
414
415
        return self.__idiv__(other)
                                               
ultimanet's avatar
ultimanet committed
416
    def __floordiv__(self, other):
        """Implements self // other via __builtin_helper__."""
        local_op = self.get_local_data().__floordiv__
        return self.__builtin_helper__(local_op, other)
ultimanet's avatar
ultimanet committed
419
    def __rfloordiv__(self, other):
        """Implements other // self via __builtin_helper__."""
        local_op = self.get_local_data().__rfloordiv__
        return self.__builtin_helper__(local_op, other)
    def __ifloordiv__(self, other):
423
424
425
        return self.__builtin_helper__(
                    self.get_local_data().__ifloordiv__, other,
                                               inplace = True)
ultimanet's avatar
ultimanet committed
426
427
428
429
430
431
432
433
    
    def __mul__(self, other):
        """Implements self * other via __builtin_helper__."""
        local_op = self.get_local_data().__mul__
        return self.__builtin_helper__(local_op, other)
    
    def __rmul__(self, other):
        """Implements other * self via __builtin_helper__."""
        local_op = self.get_local_data().__rmul__
        return self.__builtin_helper__(local_op, other)

    def __imul__(self, other):
434
435
436
        return self.__builtin_helper__(self.get_local_data().__imul__, 
                                               other,
                                               inplace = True)
Ultimanet's avatar
Ultimanet committed
437

ultimanet's avatar
ultimanet committed
438
439
440
441
442
443
444
    def __pow__(self, other):
        """Implements self ** other via __builtin_helper__."""
        local_op = self.get_local_data().__pow__
        return self.__builtin_helper__(local_op, other)
 
    def __rpow__(self, other):
        """Implements other ** self via __builtin_helper__."""
        local_op = self.get_local_data().__rpow__
        return self.__builtin_helper__(local_op, other)

    def __ipow__(self, other):
445
        return self.__builtin_helper__(self.get_local_data().__ipow__, 
446
447
                                               other,
                                               inplace = True)
Ultimanet's avatar
Ultimanet committed
448
   
449
450
    def __len__(self):
        return self.shape[0]
451
    
452
453
454
    def dim(self):
        return np.prod(self.shape)
        
455
    def vdot(self, other):
456
        other = self.distributor.extract_local_data(other)
457
458
459
460
461
        local_vdot = np.vdot(self.get_local_data(), other)
        local_vdot_list = self.distributor._allgather(local_vdot)
        global_vdot = np.sum(local_vdot_list)
        return global_vdot
            
Ultimanet's avatar
Ultimanet committed
462

463
    
ultimanet's avatar
ultimanet committed
464
    def __getitem__(self, key):
Ultimanet's avatar
Ultimanet committed
465
466
467
468
469
        ## Case 1: key is a boolean array.
        ## -> take the local data portion from key, use this for data 
        ## extraction, and then merge the result in a flat numpy array
        if isinstance(key, np.ndarray):
            found = 'ndarray'
470
            found_boolean = (key.dtype.type == np.bool_)
Ultimanet's avatar
Ultimanet committed
471
472
        elif isinstance(key, distributed_data_object):
            found = 'd2o'
473
            found_boolean = (key.dtype == np.bool_)
Ultimanet's avatar
Ultimanet committed
474
475
        else:
            found = 'other'
Ultima's avatar
Ultima committed
476
        ## TODO: transfer this into distributor:
Ultimanet's avatar
Ultimanet committed
477
478
479
480
481
482
483
484
485
486
        if (found == 'ndarray' or found == 'd2o') and found_boolean == True:
            ## extract the data of local relevance
            local_bool_array = self.distributor.extract_local_data(key)
            local_results = self.get_local_data(copy=False)[local_bool_array]
            global_results = self.distributor._allgather(local_results)
            global_results = np.concatenate(global_results)
            return global_results            
            
        else:
            return self.get_data(key)
ultimanet's avatar
ultimanet committed
487
488
489
490
    
    def __setitem__(self, key, data):
        """Stores data in the region given by key; see set_data."""
        self.set_data(data=data, key=key)
        
491
    def _contraction_helper(self, function, **kwargs):
492
493
494
495
496
497
        local = function(self.data, **kwargs)
        local_list = self.distributor._allgather(local)
        global_ = function(local_list, axis=0)
        return global_
        
    def amin(self, **kwargs):
498
        return self._contraction_helper(np.amin, **kwargs)
499
500

    def nanmin(self, **kwargs):
501
        return self._contraction_helper(np.nanmin, **kwargs)
502
503
        
    def amax(self, **kwargs):
504
        return self._contraction_helper(np.amax, **kwargs)
505
506
    
    def nanmax(self, **kwargs):
507
        return self._contraction_helper(np.nanmax, **kwargs)
Ultimanet's avatar
Ultimanet committed
508
    
509
510
511
512
513
514
    def sum(self, **kwargs):
        return self._contraction_helper(np.sum, **kwargs)

    def prod(self, **kwargs):
        return self._contraction_helper(np.prod, **kwargs)        
        
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
    def mean(self, power=1):
        ## compute the local means and the weights for the mean-mean. 
        local_mean = np.mean(self.data**power)
        local_weight = np.prod(self.data.shape)
        ## collect the local means and cast the result to a ndarray
        local_mean_weight_list = self.distributor._allgather((local_mean, 
                                                              local_weight))
        local_mean_weight_list =np.array(local_mean_weight_list)   
        ## compute the denominator for the weighted mean-mean                                                           
        global_weight = np.sum(local_mean_weight_list[:,1])
        ## compute the numerator
        numerator = np.sum(local_mean_weight_list[:,0]*\
            local_mean_weight_list[:,1])
        global_mean = numerator/global_weight
        return global_mean

    def var(self):
        mean_of_the_square = self.mean(power=2)
        square_of_the_mean = self.mean()**2
        return mean_of_the_square - square_of_the_mean
    
    def std(self):
        return np.sqrt(self.var())
        
539
540
541
542
543
544
545
546
547
548
549
#    def _argmin_argmax_flat_helper(self, function):
#        local_argmin = function(self.data)
#        local_argmin_value = self.data[np.unravel_index(local_argmin, 
#                                                        self.data.shape)]
#        globalized_local_argmin = self.distributor.globalize_flat_index(local_argmin)                                                       
#        local_argmin_list = self.distributor._allgather((local_argmin_value, 
#                                                         globalized_local_argmin))
#        local_argmin_list = np.array(local_argmin_list, dtype=[('value', int),
#                                                               ('index', int)])    
#        return local_argmin_list
#        
550
551
552
553
    def argmin_flat(self):
        local_argmin = np.argmin(self.data)
        local_argmin_value = self.data[np.unravel_index(local_argmin, 
                                                        self.data.shape)]
554
555
        globalized_local_argmin = self.distributor.globalize_flat_index(
                                                                local_argmin)                                                       
556
        local_argmin_list = self.distributor._allgather((local_argmin_value, 
557
558
559
560
561
562
                                                    globalized_local_argmin))
        local_argmin_list = np.array(local_argmin_list, dtype=[
                                        ('value', local_argmin_value.dtype),
                                        ('index', int)])    
        local_argmin_list = np.sort(local_argmin_list, 
                                    order=['value', 'index'])        
563
564
565
566
567
568
        return local_argmin_list[0][1]
    
    def argmax_flat(self):
        local_argmax = np.argmax(self.data)
        local_argmax_value = -self.data[np.unravel_index(local_argmax, 
                                                        self.data.shape)]
569
570
        globalized_local_argmax = self.distributor.globalize_flat_index(
                                                                local_argmax)                                                       
571
        local_argmax_list = self.distributor._allgather((local_argmax_value, 
572
573
574
575
576
577
                                                    globalized_local_argmax))
        local_argmax_list = np.array(local_argmax_list, dtype=[
                                        ('value', local_argmax_value.dtype),
                                        ('index', int)]) 
        local_argmax_list = np.sort(local_argmax_list, 
                                    order=['value', 'index'])        
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
        return local_argmax_list[0][1]
        

    def argmin(self):    
        return np.unravel_index(self.argmin_flat(), self.shape)
    
    def argmax(self):
        return np.unravel_index(self.argmax_flat(), self.shape)
    
    def conjugate(self):
        """
            Returns a new object holding the complex conjugate of the 
            local data.
        """
        conjugated = self.copy_empty()
        local_conjugate = np.conj(self.get_local_data())
        conjugated.set_local_data(local_conjugate)
        return conjugated

    
    def conj(self):
        """Short form of conjugate()."""
        conjugated = self.conjugate()
        return conjugated
        
    def median(self):
        """
            Returns np.median of the full (consolidated) data. A warning
            about the cost is printed on every call.
        """
        about.warnings.cprint(
            "WARNING: The current implementation of median is very expensive!")
        full_data = self.get_full_data()
        return np.median(full_data)
        
603
    def iscomplex(self):
        """
            Returns a boolean distributed_data_object holding 
            np.iscomplex of the local data.
        """
        flags = self.copy_empty(dtype=np.bool_)
        local_flags = np.iscomplex(self.data)
        flags.set_local_data(local_flags)
        return flags
    
    def isreal(self):
        """
            Returns a boolean distributed_data_object holding np.isreal 
            of the local data.
        """
        flags = self.copy_empty(dtype=np.bool_)
        local_flags = np.isreal(self.data)
        flags.set_local_data(local_flags)
        return flags
    
613

614
615
616
617
618
619
620
621
    def all(self):
        local_all = np.all(self.get_local_data())
        global_all = self.distributor._allgather(local_all)
        return np.all(global_all)

    def any(self):
        local_any = np.any(self.get_local_data())
        global_any = self.distributor._allgather(local_any)
622
        return np.any(global_any)
623
624
625
        
    
    
626
    def set_local_data(self, data, hermitian=False, copy=True):
ultimanet's avatar
ultimanet committed
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
        """
            Stores data directly in the local data attribute. No distribution 
            is done. The shape of the data must fit the local data attributes
            shape.

            Parameters
            ----------
            data : tuple, list, numpy.ndarray 
                The data which should be stored in the local data attribute.
            
            Returns
            -------
            None
        
        """
Ultimanet's avatar
Ultimanet committed
642
643
        self.hermitian = hermitian
        self.data = np.array(data, dtype=self.dtype, copy=copy, order='C')
ultimanet's avatar
ultimanet committed
644
    
645
    def set_data(self, data, key, hermitian=False, copy=True, *args, **kwargs):
        """
            Writes data into the region of the object selected by key. 
            The actual update of the local arrays is done by the 
            distributor's disperse_data.
            
            Parameters
            ----------
            data : tuple, list, numpy.ndarray 
                The data which should be distributed.
            key : int, slice, tuple of int or slice
                Specifies the region in which data will be stored.
            hermitian : bool, *optional*
                New value of the hermitian flag (default False).
            copy : bool, *optional*
                Forwarded to disperse_data.
            *args, **kwargs
                Forwarded to disperse_data.
            
            Returns
            -------
            None
        
        """
        self.hermitian = hermitian
        slices, sliceified = self.__sliceify__(key)
        data_update = self.__enfold__(data, sliceified)
        self.distributor.disperse_data(data=self.data, 
                        to_slices = slices,
                        data_update = data_update,
                        copy = copy,
                        *args, **kwargs)
ultimanet's avatar
ultimanet committed
672
    
673
    def set_full_data(self, data, hermitian=False, copy = True, **kwargs):
        """
            Distributes the supplied data to the nodes. The shape of data
            must match the shape of the distributed_data_object.

            Parameters
            ----------
            data : tuple, list, numpy.ndarray
                The data which should be distributed.
            hermitian : boolean *optional*
                Value that is written to the object's `hermitian` attribute.
            copy : boolean *optional*
                Forwarded unchanged to the distributor's `distribute_data`.
            **kwargs
                Forwarded unchanged to the distributor's `distribute_data`.

            Notes
            -----
            set_full_data(foo) is equivalent to set_data(foo,slice(None)) but
            faster.

            Returns
            -------
            None

        """
        self.hermitian = hermitian
        local_part = self.distributor.distribute_data(data=data,
                                                      copy=copy,
                                                      **kwargs)
        self.data = local_part
ultimanet's avatar
ultimanet committed
696
697
    

Ultimanet's avatar
Ultimanet committed
698
    def get_local_data(self, key=(slice(None),), copy=True):
        """
            Loads data directly from the local data attribute. No
            consolidation is done.

            Parameters
            ----------
            key : int, slice, tuple of int or slice
                The key which will be used to access the local data. It is
                honoured regardless of the value of `copy`.
            copy : boolean *optional*
                If true (default), an independent copy of the selected data
                is returned. Otherwise the result of indexing the local data
                is returned as it is; for slice keys numpy hands back a view
                that shares memory with the local data.

            Returns
            -------
            self.data[key] : numpy.ndarray
                For keys that select a single element, the numpy scalar is
                returned.

        """
        local = self.data[key]
        ## Scalars are immutable, so only arrays need an explicit copy.
        if copy and isinstance(local, np.ndarray):
            return local.copy()
        return local
ultimanet's avatar
ultimanet committed
717
        
718
    def get_data(self, key, **kwargs):
        """
            Loads data from the region which is specified by key. The data is
            consolidated according to the distribution strategy. If the
            individual nodes get different key-arguments, they get individual
            data.

            Parameters
            ----------
            key : int, slice, tuple of int or slice
                The key is the object which specifies the region, where data
                will be loaded from.
            **kwargs
                Forwarded unchanged to the distributor's `collect_data`.

            Returns
            -------
            global_data[key] : numpy.ndarray

        """
        from_slices, folded_axes = self.__sliceify__(key)
        collected = self.distributor.collect_data(self.data,
                                                  from_slices,
                                                  **kwargs)
        ## Drop the length-one axes that stem from integer entries in key.
        return self.__defold__(collected, folded_axes)
        
    
    
    def get_full_data(self, target_rank='all'):
        """
            Fully consolidates the distributed data.

            Parameters
            ----------
            target_rank : 'all' (default), int *optional*
                If only one node should receive the full data, it can be
                specified here.

            Notes
            -----
            get_full_data() is equivalent to get_data(slice(None)) but
            faster.

            Returns
            -------
            The object returned by the distributor's `consolidate_data`
            (presumably the consolidated numpy.ndarray).
        """
        consolidated = self.distributor.consolidate_data(
                                                self.data,
                                                target_rank=target_rank)
        return consolidated
ultimanet's avatar
ultimanet committed
765

Ultimanet's avatar
Ultimanet committed
766
767
768
769
770
771
772
    def inject(self, to_slices=(slice(None),), data=None, 
               from_slices=(slice(None),)):
        """
            Hands `data` over to the distributor's `inject` method together
            with the local data and the two slice tuples.

            Parameters
            ----------
            to_slices : tuple of slices *optional*
                Passed to the distributor as the target region
                (presumably a region of this object's data).
            data : object *optional*
                The data to inject. If it is None, nothing is done.
            from_slices : tuple of slices *optional*
                Passed to the distributor as the source region
                (presumably a region of `data`).

            Returns
            -------
            self if `data` is None, otherwise None.
        """
        ## Identity check: `data == None` is evaluated elementwise for numpy
        ## arrays and the truth value of the result is ambiguous.
        if data is None:
            return self

        self.distributor.inject(self.data, to_slices, data, from_slices)
        
773
774
775
776
777
778
779
780
781
782
783
    def flatten(self, inplace = False):
        """
            Builds a one-dimensional distributed_data_object from this one.

            An empty copy with the global shape (prod(self.shape),) is
            created and its local data is set to the result of the
            distributor's `flatten`.

            Parameters
            ----------
            inplace : boolean *optional*
                Forwarded to the distributor's `flatten`. Note that this
                object is never replaced by the flattened one; the new
                object is returned in either case.

            Returns
            -------
            The newly created, flattened object.
        """
        total_size = np.prod(self.shape)
        flat_d2o = self.copy_empty(global_shape=(total_size,))
        flat_local = self.distributor.flatten(self.data, inplace=inplace)
        flat_d2o.set_local_data(data=flat_local)
        return flat_d2o
        
Ultimanet's avatar
Ultimanet committed
784
        
785

ultimanet's avatar
ultimanet committed
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
      
    def save(self, alias, path=None, overwriteQ=True):
        """
            Saves a distributed_data_object to disk utilizing h5py.

            The work is delegated to the distributor's `save_data`, which
            receives the local data together with all arguments.

            Parameters
            ----------
            alias : string
                The name for the dataset which is saved within the hdf5 file.

            path : string *optional*
                The path to the hdf5 file. If no path is given, the alias is
                taken as filename in the current path.

            overwriteQ : Boolean *optional*
                Specifies whether a dataset may be overwritten if it is
                already present in the given hdf5 file or not.

            Returns
            -------
            None
        """
        local_data = self.data
        self.distributor.save_data(local_data, alias, path, overwriteQ)

    def load(self, alias, path=None):
        """
            Loads a distributed_data_object from disk utilizing h5py.

            The result of the distributor's `load_data` replaces the local
            data attribute.

            Parameters
            ----------
            alias : string
                The name of the dataset which is loaded from the hdf5 file.

            path : string *optional*
                The path to the hdf5 file. If no path is given, the alias is
                taken as filename in the current path.

            Returns
            -------
            None
        """
        loaded = self.distributor.load_data(alias, path)
        self.data = loaded
           
    def __sliceify__(self, inp):
        """
            Converts a key into a tuple that consists of slices only.

            Parameters
            ----------
            inp : int, slice, tuple or list of int or slice
                The key. Everything that is neither a tuple nor a list is
                treated as a one-element key.

            Returns
            -------
            (slices, sliceified) : (tuple of slices, list of booleans)
                `slices` contains the original slice objects; every other
                entry i is replaced by a slice that selects exactly element
                i. `sliceified` is True at the positions where such a
                replacement took place.
        """
        if isinstance(inp, tuple):
            entries = inp
        elif isinstance(inp, list):
            entries = tuple(inp)
        else:
            entries = (inp, )

        result = []
        sliceified = []
        for entry in entries:
            if isinstance(entry, slice):
                result.append(entry)
                sliceified.append(False)
            else:
                stop = entry + 1
                ## slice(-1, 0) would be empty; the last element must be
                ## addressed with an open upper bound.
                if stop == 0:
                    stop = None
                result.append(slice(entry, stop))
                sliceified.append(True)

        return (tuple(result), sliceified)
                
                
    def __enfold__(self, in_data, sliceified):
        """
            Reshapes `in_data` such that it gets a length-one axis for every
            key entry that was turned into a slice by __sliceify__.

            Parameters
            ----------
            in_data : array-like
                The data to be reshaped. It is converted with numpy.asarray,
                so an existing ndarray is not copied.
            sliceified : list of booleans
                The second return value of __sliceify__.

            Returns
            -------
            numpy.ndarray
                `in_data` reshaped to the computed shape.
        """
        ## asarray only copies if it has to; np.array(..., copy=False)
        ## raises for such inputs (e.g. lists) in NumPy >= 2.
        data = np.asarray(in_data)
        data_shape = data.shape
        data_ndim = len(data_shape)

        temp_shape = ()
        j = 0    ## position of the next not yet consumed axis of data
        for was_index in sliceified:
            if was_index == True:
                temp_shape += (1,)
                ## if the data already carries the length-one axis, consume it
                if j < data_ndim and data_shape[j] == 1:
                    j += 1
            else:
                ## take the extent from data; pad with 1 if data has no
                ## axes left
                if j < data_ndim:
                    temp_shape += (data_shape[j],)
                else:
                    temp_shape += (1,)
                j += 1
        ## take into account that the sliceified list may be too short,
        ## because of a non-exhaustive list of slices
        temp_shape += tuple(data_shape[j:])

        return data.reshape(temp_shape)
    
    def __defold__(self, data, sliceified):
        """
            Removes the axes from `data` that belong to key entries which
            were turned into slices by __sliceify__.

            Parameters
            ----------
            data : numpy.ndarray
                The data to be indexed.
            sliceified : list of booleans
                The second return value of __sliceify__.

            Returns
            -------
            `data` indexed with 0 at every position where `sliceified` is
            True and with slice(None) everywhere else.
        """
        selector = tuple(0 if flag == True else slice(None)
                         for flag in sliceified)
        return data[selector]

    
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
class _distributor_factory(object):
    '''
        Produces distributor objects and caches them in `distributor_store`,
        so that a request with the same arguments gets the same distributor.

        Comments:
          - The distributor's get_data and set_data functions MUST be 
            supplied with a tuple of slice objects. In case that there was 
            a direct integer involved, the unfolding will be done by the
            helper functions __sliceify__, __enfold__ and __defold__.
    '''
    def __init__(self):
        ## maps the hash of the request arguments to the produced distributor
        self.distributor_store = {}
    
    def parse_kwargs(self, strategy = None, kwargs = None):
        '''
            Picks those keyword arguments that are relevant for `strategy`.

            Parameters
            ----------
            strategy : string *optional*
                The distribution strategy.
            kwargs : dict *optional*
                The keyword arguments to be filtered.

            Returns
            -------
            dict
                Contains the entry 'comm' if the strategy is 'fftw' or
                'equal' and `kwargs` has such an entry; empty otherwise.
        '''
        return_dict = {}
        if kwargs is None:
            kwargs = {}
        if strategy in ('fftw', 'equal'):
            if 'comm' in kwargs:
                return_dict['comm'] = kwargs['comm']
        return return_dict
                        
    def hash_arguments(self, global_shape, dtype, distribution_strategy,
                       kwargs=None):
        '''
            Builds a hashable key from the arguments of a distributor
            request.

            Parameters
            ----------
            global_shape : hashable
                The global shape of the requested distributor.
            dtype : object
                Passed through `dictionize_np`.
            distribution_strategy : string
                The distribution strategy.
            kwargs : dict *optional*
                Further arguments. The dictionary is copied and not altered.
                An entry 'comm' is replaced by its id().

            Returns
            -------
            frozenset
                The items of the assembled dictionary.
        '''
        kwargs = {} if kwargs is None else kwargs.copy()
        if 'comm' in kwargs:
            kwargs['comm'] = id(kwargs['comm'])
        kwargs['global_shape'] = global_shape        
        kwargs['dtype'] = self.dictionize_np(dtype)
        kwargs['distribution_strategy'] = distribution_strategy
        return frozenset(kwargs.items())

    def dictionize_np(self, x):
        '''
            Returns the items of x.__dict__ as a frozenset.
        '''
        dic = x.__dict__.items()
        ## NOTE(review): the positional indices below depend on the order of
        ## the item list of np.float's __dict__ and on items() returning a
        ## list; presumably they blank out entries that spoil the hash --
        ## confirm before touching.
        if x is np.float:
            dic[24] = 0 
            dic[29] = 0
            dic[37] = 0
        return frozenset(dic)            
            
    def get_distributor(self, distribution_strategy, global_shape, dtype,
                        **kwargs):
        '''
            Returns the distributor for the given arguments. It is produced
            on the first request and taken from the store afterwards.

            Parameters
            ----------
            distribution_strategy : string
                'not', 'equal' or 'fftw'. The latter is only known if both
                pyfftw and MPI were found.
            global_shape : hashable
                The global shape of the data.
            dtype : object
                The datatype of the data.
            **kwargs
                Filtered by `parse_kwargs`; the remaining entries are passed
                to the slicing distributors.

            Returns
            -------
            The requested distributor object.

            Raises
            ------
            TypeError
                If the distribution strategy is not known.
        '''
        ## check if the distribution strategy is known
        known_distribution_strategies = ['not', 'equal']
        if found['pyfftw'] == True and found['MPI'] == True:
            known_distribution_strategies += ['fftw',]
        ## use the list assembled above, so that 'fftw' is rejected here
        ## when its requirements are missing
        if distribution_strategy not in known_distribution_strategies:
            raise TypeError(about._errors.cstring(
                "ERROR: Unknown distribution strategy supplied."))
                
        ## parse the kwargs
        parsed_kwargs = self.parse_kwargs(strategy = distribution_strategy,
                                          kwargs = kwargs)
        hashed_arguments = self.hash_arguments(global_shape = global_shape,
                                               dtype = dtype,
                                distribution_strategy = distribution_strategy,
                                               kwargs = parsed_kwargs)
        ## check if the distributor has already been produced in the past
        if hashed_arguments in self.distributor_store:
            return self.distributor_store[hashed_arguments]

        ## produce new distributor
        if distribution_strategy == 'not':
            produced_distributor = _not_distributor(
                                                global_shape = global_shape,
                                                dtype = dtype)
        elif distribution_strategy == 'equal':
            produced_distributor = _slicing_distributor(
                                                slicer = _equal_slicer,
                                                global_shape = global_shape,
                                                dtype = dtype,
                                                **parsed_kwargs)
        elif distribution_strategy == 'fftw':
            produced_distributor = _slicing_distributor(
                                                slicer = _fftw_slicer,
                                                global_shape = global_shape,
                                                dtype = dtype,
                                                **parsed_kwargs)
        self.distributor_store[hashed_arguments] = produced_distributor
        return self.distributor_store[hashed_arguments]
            
            
## Module-level factory instance; its distributor_store keeps every 
## distributor that was produced through it.
distributor_factory = _distributor_factory()
ultimanet's avatar
ultimanet committed
962
        
963
964
class _slicing_distributor(object):
    
ultimanet's avatar
ultimanet committed
965

966
967
    def __init__(self, slicer, global_shape=None, dtype=None, 
                 comm=MPI.COMM_WORLD):
ultimanet's avatar
ultimanet committed
968
969
        
        if comm.rank == 0:        
970
971
972
973
974
            if global_shape is None:
                raise TypeError(about._errors.cstring(
                    "ERROR: No shape supplied!"))
            else:
                self.global_shape = global_shape      
ultimanet's avatar
ultimanet committed
975
976
        else:
            self.global_shape = None
Ultimanet's avatar
Ultimanet committed
977
            
ultimanet's avatar
ultimanet committed
978
979
980
981
        self.global_shape = comm.bcast(self.global_shape, root = 0)
        self.global_shape = tuple(self.global_shape)
        
        if comm.rank == 0:        
982
983
984
                if dtype is None:        
                    raise TypeError(about._errors.cstring(
                    "ERROR: Failed setting datatype! No datatype supplied."))
ultimanet's avatar
ultimanet committed
985
                else:
986
                    self.dtype = dtype                    
ultimanet's avatar
ultimanet committed
987
988
989
        else:
            self.dtype=None
        self.dtype = comm.bcast(self.dtype, root=0)
990

ultimanet's avatar
ultimanet committed
991
        
992
        self._my_dtype_converter = _global_dtype_converter
ultimanet's avatar
ultimanet committed
993
994
        
        if not self._my_dtype_converter.known_np_Q(self.dtype):
Ultimanet's avatar
Ultimanet committed
995
            raise TypeError(about._errors.cstring(\
996
            "ERROR: The datatype "+str(self.dtype)+" is not known to mpi4py."))
ultimanet's avatar
ultimanet committed
997
998
999

        self.mpi_dtype  = self._my_dtype_converter.to_mpi(self.dtype)
        
1000
        #self._local_size = pyfftw.local_size(self.global_shape)