nifty_fft.py 25.3 KB
Newer Older
1
2
3
# -*- coding: utf-8 -*-

import numpy as np
Jait Dixit's avatar
Jait Dixit committed
4
from mpi4py import MPI
5
6
from d2o import distributed_data_object, distributor_factory, STRATEGIES
from nifty.config import about, dependency_injector as gdi
7
import nifty.nifty_utilities as utilities
Ultima's avatar
Ultima committed
8
9
10
11
12

# Optional FFT backends, looked up through the dependency injector. A name is
# None when its module is unavailable (the classes below check for that).
pyfftw, gfft, gfft_dummy = (gdi.get(module_name)
                            for module_name in ('pyfftw', 'gfft', 'gfft_dummy'))

13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

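# Legacy import logic, kept for reference only: try to import pyfftw and fall
# back to gfft if that fails; if gfft is missing too, fall back to the bundled
# gfft_rg. The backends are now resolved through the dependency injector above.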

# try:
#    import pyfftw
#    fft_machine='pyfftw'
# except(ImportError):
#    try:
#        import gfft
#        fft_machine='gfft'
#        about.infos.cprint('INFO: Using gfft')
#    except(ImportError):
#        import gfft_rg as gfft
#        fft_machine='gfft_fallback'
#        about.infos.cprint('INFO: Using builtin "plain" gfft version 0.1.0')


Ultima's avatar
Ultima committed
31
def fft_factory(fft_module_name):
    """
        A factory for fast-fourier-transformation objects.

        Parameters
        ----------
        fft_module_name : String
            Select an FFT module: 'pyfftw', 'gfft' or 'gfft_dummy'.

        Returns
        -------
        fft : FFTW or GFFT
            An fft object wrapping the requested module.

        Raises
        ------
        ValueError
            If fft_module_name is not one of the supported names.
    """
    if fft_module_name == 'pyfftw':
        return FFTW()
    # Membership test: a bare `== 'gfft' or 'gfft_dummy'` is always truthy.
    elif fft_module_name in ('gfft', 'gfft_dummy'):
        return GFFT(fft_module_name)
    else:
        raise ValueError('Given fft_module_name not known: ' +
                         str(fft_module_name))
53
54


Jait Dixit's avatar
Jait Dixit committed
55
class FFT(object):
    """
        A generic fft object without any implementation.

        Concrete backends (FFTW, GFFT) override transform().
    """

    def __init__(self):
        pass

    def transform(self, val, domain, codomain, axes, **kwargs):
        """
            A generic ff-transform function.

            Parameters
            ----------
            val : distributed_data_object or numpy.ndarray
                The value-array of the field which is supposed to
                be transformed.

            domain : nifty.rg.nifty_rg.rg_space
                The domain of the space which should be transformed.

            codomain : nifty.rg.nifty_rg.rg_space
                The target into which the field should be transformed.

            axes : tuple or None
                The axes which should be transformed.

            Raises
            ------
            NotImplementedError
                Always; this base class provides no transform.
        """
        raise NotImplementedError
81
82


Jait Dixit's avatar
Jait Dixit committed
83
class FFTW(FFT):
    """
        The pyfftw pendant of a fft object.

        numpy arrays are transformed locally via pyfftw.interfaces;
        distributed_data_objects are transformed either locally (when axis 0
        is not transformed) or through pyfftw MPI plans.
    """

    def __init__(self):
        if 'pyfftw' not in gdi:
            raise ImportError("The module pyfftw is needed but not available.")

        self.name = 'pyfftw'
        # The info_dict stores the FFTWTransformInfo objects which correspond
        # to a certain (domain, codomain, locality, local shape) setting.
        self.info_dict = {}

        # initialize the dictionary which stores the values from
        # get_centering_mask
        self.centering_mask_dict = {}

        # Enable caching for pyfftw.interfaces
        pyfftw.interfaces.cache.enable()

    def get_centering_mask(self, to_center_input, dimensions_input,
                           offset_input=0):
        """
            Computes the mask, used to (de-)zerocenter domain and target
            fields.

            The entry at index i is (-1) ** sum_k(to_center[k] * (i[k] +
            offset[k])), where only the zero-th dimension carries an offset.

            Parameters
            ----------
            to_center_input : tuple, list, numpy.ndarray
                A tuple of booleans which dimensions should be
                zero-centered.

            dimensions_input : tuple, list, numpy.ndarray
                A tuple containing the mask's desired shape.

            offset_input : int, boolean
                Specifies whether the zero-th dimension starts with an odd
                or an even index, i.e. if it is shifted.

            Returns
            -------
            result : np.ndarray or int
                A 1/-1-alternating int8 mask of the desired shape, or the
                scalar 1 if no dimension is zero-centered. Masks are cached.

            Raises
            ------
            TypeError
                If the lengths of to_center_input and dimensions_input
                differ.
        """
        # cast input
        to_center = np.atleast_1d(np.array(to_center_input))
        dimensions = np.atleast_1d(np.array(dimensions_input))

        # if none of the dimensions are zero centered, return a 1
        if np.all(to_center == 0):
            return 1

        # check for dimension match before both arrays are walked in parallel
        if to_center.size != dimensions.size:
            raise TypeError(
                'The length of the supplied lists does not match.')

        # int() also tolerates float shapes (np.append can produce those)
        shape = tuple(int(d) for d in dimensions)
        offset = int(offset_input)

        # compute an identifier for the parameter set
        temp_id = (tuple(to_center.tolist()), shape, offset)
        if temp_id not in self.centering_mask_dict:
            self.centering_mask_dict[temp_id] = \
                self._compute_centering_mask(to_center, shape, offset)
        return self.centering_mask_dict[temp_id]

    @staticmethod
    def _compute_centering_mask(to_center, shape, offset):
        """
            Builds the (uncached) centering mask for get_centering_mask.
        """
        # Dimensions of size 1 are sorted out for the computation and
        # restored as new axes at the end. Their only index is 0, so they
        # contribute a constant factor: only the zero-th dimension can carry
        # an offset, which then flips the sign of the whole mask.
        is_size_one = [d == 1 for d in shape]
        flip_sign = bool(is_size_one[0] and
                         int(to_center[0]) * offset % 2)

        kept = [i for i, one in enumerate(is_size_one) if not one]
        if not kept:
            mask = np.ones(shape, dtype=np.int8)
            return -mask if flip_sign else mask

        sub_dims = np.array([shape[i] for i in kept])
        sub_center = np.array([to_center[i] for i in kept])
        # The offset belongs to the zero-th dimension; if that one was sorted
        # out, it is already accounted for by flip_sign.
        sub_offset = np.zeros(len(kept), dtype=int)
        if kept[0] == 0:
            sub_offset[0] = offset

        # The 2x...x2 core holds the alternation scheme; np.tile stacks it
        # until the desired format is constructed.
        core = np.fromfunction(
            lambda *args: (-1) **
            (np.tensordot(sub_center,
                          args +
                          sub_offset.reshape(sub_offset.shape +
                                             (1,) *
                                             (np.array(args).ndim - 1)),
                          1)),
            (2,) * sub_center.size)
        # Cast the core to the smallest integers we can get
        core = core.astype(np.int8)

        mask = np.tile(core, sub_dims // 2)
        # For dimensions of odd size one more slice is needed; the
        # second-to-last slice has the parity of the missing last one.
        for i in range(mask.ndim):
            if sub_dims[i] % 2 == 0:
                continue
            temp_slice = ((slice(None),) * i + (slice(-2, -1, 1),) +
                          (slice(None),) * (mask.ndim - 1 - i))
            mask = np.append(mask, mask[temp_slice], axis=i)

        # Add depth to the mask where the length of a dimension was one
        restore = tuple(None if one else slice(None) for one in is_size_one)
        mask = mask[restore]
        return -mask if flip_sign else mask

    def _get_transform_info(self, domain, codomain, local_shape_info=None,
                            is_local=False, **kwargs):
        """
            Returns the cached transform info object for a setting, creating
            it on first use. Further kwargs are passed to the info class.
        """
        # The centering masks depend on local_shape_info, so it is part of
        # the key (converted to hashable types).
        if local_shape_info is None:
            shape_key = None
        else:
            shape_key = (tuple(np.asarray(local_shape_info[0]).tolist()),
                         int(local_shape_info[1]))
        # Local and MPI infos are different classes (only the latter has a
        # plan), so the locality must be part of the key as well.
        temp_id = (hash(domain), hash(codomain), bool(is_local), shape_key)

        # generate the info object if not already there
        if temp_id not in self.info_dict:
            if is_local:
                info_class = FFTWLocalTransformInfo
            else:
                info_class = FFTWMPITransfromInfo
            self.info_dict[temp_id] = info_class(
                domain, codomain, local_shape_info, self, **kwargs
            )

        return self.info_dict[temp_id]

    def _apply_mask(self, val, mask, axes):
        """
            Apply centering mask to an array.

            Parameters
            ----------
            val: distributed_data_object or numpy.ndarray
                The value-array on which the mask should be applied.

            mask: numpy.ndarray
                The mask to be applied.

            axes: tuple
                The axes which are to be transformed.

            Returns
            -------
            distributed_data_object or np.nd_array
                A new array: the input multiplied by the mask.
        """
        # reshape mask if necessary
        if axes:
            mask = mask.reshape(
                [y if x in axes else 1
                    for x, y in enumerate(val.shape)]
            )

        return val * mask

    def _local_transform(self, val, info, axes, domain, codomain):
        """
            Transforms a numpy array on the local node via the info's
            pyfftw.interfaces function.
        """
        # Apply codomain centering mask (_apply_mask returns a new array,
        # so the caller's data is left untouched)
        if any(codomain.paradict['zerocenter']):
            val = self._apply_mask(val, info.cmask_codomain, axes)

        result = info.fftw_interface(val, axes=axes,
                                     planner_effort='FFTW_ESTIMATE')

        # Apply domain centering mask
        if any(domain.paradict['zerocenter']):
            result = self._apply_mask(result, info.cmask_domain, axes)

        # Correct the sign if needed
        result *= info.sign

        return result

    def _mpi_transform(self, val, info, axes, domain, codomain):
        """
            Transforms the local part of the data through the info's MPI
            plan.

            Raises
            ------
            RuntimeError
                If the plan has no output on this node.
        """
        # Apply codomain centering mask (returns a new array)
        if any(codomain.paradict['zerocenter']):
            val = self._apply_mask(val, info.cmask_codomain, axes)

        p = info.plan
        # Load the value into the plan
        if p.has_input:
            p.input_array[:] = val
        # Execute the plan
        p()

        if not p.has_output:
            raise RuntimeError('ERROR: PyFFTW-MPI transform failed.')
        # The plan is cached in info_dict and writes into its own output
        # buffer on every execution; copy so that the returned result is
        # not overwritten by the next transform.
        result = p.output_array.copy()

        # Apply domain centering mask
        if any(domain.paradict['zerocenter']):
            result = self._apply_mask(result, info.cmask_domain, axes)

        # Correct the sign if needed
        result *= info.sign

        return result

    def _not_slicing_transform(self, val, domain, codomain, axes, **kwargs):
        """
            Transforms a d2o with a non-slicing distribution strategy by
            repacking it into the 'fftw' strategy.
        """
        about.warnings.cprint(
            'WARNING: Repacking d2o to fftw distribution strategy')
        # The repacking is identical to the slicing-but-not-fftw case; that
        # path also builds the result with the transformed (complex) dtype.
        return self._slicing_not_fftw_mpi_transform(
            val, domain, codomain, axes, **kwargs
        )

    def _slicing_local_transform(self, val, domain, codomain, axes, **kwargs):
        """
            Transforms a slicing d2o node-locally; used when axis 0 (the
            distributed axis) is not transformed.
        """
        current_info = self._get_transform_info(domain, codomain,
                                                is_local=True, **kwargs)

        # Compute transform for the local data
        result = self._local_transform(
            val.get_local_data(copy=False),
            current_info, axes,
            domain, codomain
        )

        # Create return object and insert results inplace
        return_val = val.copy_empty(global_shape=val.shape,
                                    dtype=codomain.dtype)
        return_val.set_local_data(data=result, copy=False)

        return return_val

    def _slicing_not_fftw_mpi_transform(self, val, domain, codomain,
                                        axes, **kwargs):
        """
            Repacks the d2o into the 'fftw' strategy, transforms it and
            repacks the result into the original strategy.
        """
        temp_val = val.copy_empty(distribution_strategy='fftw')
        temp_val.set_full_data(val, copy=False)

        # Recursive call to transform
        result = self.transform(temp_val, domain, codomain, axes, **kwargs)

        return_val = result.copy_empty(
            distribution_strategy=val.distribution_strategy
        )
        return_val.set_full_data(data=result, copy=False)

        return return_val

    def _get_local_shape_info(self, comm, global_shape, distribution_strategy):
        """
            Returns (local_shape, local_offset) of this node's slab for the
            'equal' distribution strategy and None for any other strategy.
        """
        # NOTE(review): within this file, this is only reached from
        # _slicing_fftw_mpi_transform with the 'fftw' strategy, which yields
        # None; the centering masks are then built for the global shape.
        # Confirm that this matches the local data on more than one rank.
        if distribution_strategy == 'equal':
            local_slice = distributor_factory._equal_slicer(comm, global_shape)
            local_shape = np.append((local_slice[1] - local_slice[0],),
                                    global_shape[1:])
            return (local_shape, local_slice[0])
        return None

    def _slicing_fftw_mpi_transform(self, val, domain, codomain,
                                    axes, **kwargs):
        """
            Transforms a d2o in the 'fftw' distribution strategy through
            MPI plans, one slice of the local data at a time.
        """
        current_info = self._get_transform_info(
            domain, codomain,
            local_shape_info=self._get_local_shape_info(
                val.comm, val.shape, val.distribution_strategy
            ),
            **kwargs
        )

        return_val = val.copy_empty(global_shape=val.shape,
                                    dtype=codomain.dtype)

        # Extract local data
        local_val = val.get_local_data(copy=False)

        # Storage for the assembled local result
        temp_val = None

        # If axes tuple includes all axes, set it to None
        if axes is not None and set(axes) == set(range(len(val.shape))):
            axes = None

        full_slice = [slice(None, None)]
        for slice_list in utilities.get_slice_list(local_val.shape, axes):
            if slice_list == full_slice:
                inp = local_val
            else:
                if temp_val is None:
                    # The MPI plans produce complex128 output; a buffer of the
                    # input's (possibly real) dtype would drop the imaginary
                    # part on assignment.
                    temp_val = np.empty(local_val.shape, dtype=np.complex128)
                # numpy needs a tuple (not a list) for multi-dim indexing
                inp = local_val[tuple(slice_list)]

            # This is in order to make FFTW behave properly when slicing input
            # over MPI ranks when the input is 1-dimensional. The default
            # behaviour is to optimize to take advantage of byte-alignment,
            # which doesn't match the slicing strategy for multi-dimensional
            # data.
            original_shape = None
            if len(inp.shape) == 1:
                original_shape = inp.shape
                inp = inp.reshape(inp.shape[0], 1)

            result = self._mpi_transform(inp, current_info, axes,
                                         domain, codomain)

            # Revert to the original shape, i.e. before the input was
            # augmented with 1 to make FFTW behave properly.
            if original_shape is not None:
                result = result.reshape(original_shape)

            if slice_list == full_slice:
                temp_val = result
            else:
                temp_val[tuple(slice_list)] = result

        return_val.set_local_data(data=temp_val, copy=False)

        return return_val

    def transform(self, val, domain, codomain, axes=None, **kwargs):
        """
            The pyfftw transform function.

            Parameters
            ----------
            val : distributed_data_object or numpy.ndarray
                The value-array of the field which is supposed to
                be transformed.

            domain : nifty.rg.nifty_rg.rg_space
                The domain of the space which should be transformed.

            codomain : nifty.rg.nifty_rg.rg_space
                The target into which the field should be transformed.

            axes: tuple, None
                The axes which should be transformed.

            **kwargs : *optional*
                Further kwargs are passed to the create_mpi_plan routine.

            Returns
            -------
            result : np.ndarray or distributed_data_object
                Fourier-transformed pendant of the input field.

            Raises
            ------
            ValueError
                If an axis is out of range for val.
            RuntimeError
                If a d2o uses a comm object other than MPI.COMM_WORLD.
        """
        # Check if the axes provided are valid given the shape
        if axes is not None and \
                not all(axis in range(len(val.shape)) for axis in axes):
            raise ValueError("ERROR: Provided axes does not match array shape")

        # If the input is a numpy array we transform it locally
        if not isinstance(val, distributed_data_object):
            # Cast to a np.ndarray
            temp_val = np.asarray(val)

            current_info = self._get_transform_info(domain, codomain,
                                                    is_local=True,
                                                    **kwargs)

            # local transform doesn't apply transforms inplace
            return self._local_transform(temp_val, current_info, axes,
                                         domain, codomain)

        if val.comm is not MPI.COMM_WORLD:
            raise RuntimeError('ERROR: Input array uses an unsupported '
                               'comm object')

        if val.distribution_strategy in STRATEGIES['slicing']:
            # Transforms along the distributed axis 0 need the MPI plans;
            # all other transforms can be done node-locally.
            if axes is None or set(axes) == set(range(len(val.shape))) \
                    or 0 in axes:
                if val.distribution_strategy != 'fftw':
                    return_val = self._slicing_not_fftw_mpi_transform(
                        val, domain, codomain, axes, **kwargs
                    )
                else:
                    return_val = self._slicing_fftw_mpi_transform(
                        val, domain, codomain, axes, **kwargs
                    )
            else:
                return_val = self._slicing_local_transform(
                    val, domain, codomain, axes, **kwargs
                )
        else:
            return_val = self._not_slicing_transform(
                val, domain, codomain, axes, **kwargs
            )

        # If domain is purely real, the result of the FFT is hermitian
        if domain.paradict['complexity'] == 0:
            return_val.hermitian = True

        return return_val
493

Ultima's avatar
Ultima committed
494

Jait Dixit's avatar
Jait Dixit committed
495
496
class FFTWTransformInfo(object):
    """
        Base class holding the centering masks and the sign correction
        for an FFTW transform between domain and codomain.

        Raises
        ------
        ImportError
            If pyfftw is not available.
    """

    def __init__(self, domain, codomain, local_shape_info,
                 fftw_context, **kwargs):
        if pyfftw is None:
            raise ImportError("The module pyfftw is needed but not available.")

        # When the domain being transformed is not split across ranks, the
        # masks have the same shapes as domain and codomain, and the offset
        # is False since every node's indices start at 0. Otherwise the
        # supplied local_shape_info provides the local shape and offset.
        if local_shape_info is None:
            domain_shape = domain.get_shape()
            codomain_shape = codomain.get_shape()
            offset = False
        else:
            domain_shape = codomain_shape = local_shape_info[0]
            offset = local_shape_info[1] % 2

        self.cmask_domain = fftw_context.get_centering_mask(
            domain.paradict['zerocenter'],
            domain_shape,
            offset
        )
        self.cmask_codomain = fftw_context.get_centering_mask(
            codomain.paradict['zerocenter'],
            codomain_shape,
            offset
        )

        # If both domain and codomain are zero-centered the result,
        # will get a global minus. Store the sign to correct it.
        self.sign = (-1) ** np.sum(
            np.array(domain.paradict['zerocenter']) *
            np.array(codomain.paradict['zerocenter']) *
            (np.array(domain.get_shape()) // 2 % 2)
        )

    @property
    def cmask_domain(self):
        return self._domain_centering_mask

    @cmask_domain.setter
    def cmask_domain(self, cmask):
        self._domain_centering_mask = cmask

    @property
    def cmask_codomain(self):
        return self._codomain_centering_mask

    @cmask_codomain.setter
    def cmask_codomain(self, cmask):
        self._codomain_centering_mask = cmask

    @property
    def sign(self):
        return self._sign

    @sign.setter
    def sign(self, sign):
        self._sign = sign
Ultima's avatar
Ultima committed
564

Jait Dixit's avatar
Jait Dixit committed
565
566

class FFTWLocalTransformInfo(FFTWTransformInfo):
    """
        Transform info for node-local transforms through
        pyfftw.interfaces.numpy_fft.
    """

    def __init__(self, domain, codomain, local_shape_info,
                 fftw_context, **kwargs):
        super(FFTWLocalTransformInfo, self).__init__(
            domain, codomain, local_shape_info, fftw_context, **kwargs
        )
        # Forward transform into a harmonic codomain, backward otherwise.
        if codomain.harmonic:
            self._fftw_interface = pyfftw.interfaces.numpy_fft.fftn
        else:
            self._fftw_interface = pyfftw.interfaces.numpy_fft.ifftn

    @property
    def fftw_interface(self):
        return self._fftw_interface

    @fftw_interface.setter
    def fftw_interface(self, interface):
        # The interface is fixed at construction; assignments are ignored.
        about.warnings.cprint('WARNING: FFTWLocalTransformInfo '
                              'fftw_interface cannot be modified')


class FFTWMPITransfromInfo(FFTWTransformInfo):
    """
        Transform info for distributed transforms, carrying a pyfftw MPI
        plan in addition to the masks and sign of FFTWTransformInfo.
    """

    def __init__(self, domain, codomain, local_shape_info,
                 fftw_context, **kwargs):
        super(FFTWMPITransfromInfo, self).__init__(
            domain, codomain, local_shape_info, fftw_context, **kwargs
        )

        domain_shape = domain.get_shape()
        # A 1-dimensional domain gets a trailing axis of length 1 so that the
        # plan accepts input that was augmented the same way.
        if len(domain_shape) != 1:
            plan_shape = domain_shape
        else:
            plan_shape = (domain_shape[0], 1)

        direction = 'FFTW_BACKWARD'
        if codomain.harmonic:
            direction = 'FFTW_FORWARD'

        self._plan = pyfftw.create_mpi_plan(
            input_shape=plan_shape,
            input_dtype='complex128',
            output_dtype='complex128',
            direction=direction,
            flags=["FFTW_ESTIMATE"],
            **kwargs
        )

    @property
    def plan(self):
        return self._plan

    @plan.setter
    def plan(self, plan):
        # Assignments are ignored; only a warning is printed.
        about.warnings.cprint('WARNING: FFTWMPITransfromInfo plan \
                               cannot be modified')
Ultima's avatar
Ultima committed
617

618

Jait Dixit's avatar
Jait Dixit committed
619
class GFFT(FFT):
    """
        The gfft pendant of a fft object.

        Parameters
        ----------
        fft_module_name : String
            Switch between the gfft module used: 'gfft' and 'gfft_dummy'

        Raises
        ------
        ImportError
            If the selected module is not available.
    """

    def __init__(self, fft_module_name):
        self.name = fft_module_name
        self.fft_machine = gdi.get(fft_module_name)
        if self.fft_machine is None:
            raise ImportError(
                "The gfft(_dummy)-module is needed but not available.")

    def transform(self, val, domain, codomain, axes=None, **kwargs):
        """
            The gfft transform function.

            Parameters
            ----------
            val : numpy.ndarray or distributed_data_object
                The value-array of the field which is supposed to
                be transformed.

            domain : nifty.rg.nifty_rg.rg_space
                The domain of the space which should be transformed.

            codomain : nifty.rg.nifty_rg.rg_space
                The target into which the field should be transformed.

            axes : None or tuple
                The axes which should be transformed.

            **kwargs : *optional*
                Further kwargs are not processed.

            Returns
            -------
            result : np.ndarray or distributed_data_object
                Fourier-transformed pendant of the input field.

            Raises
            ------
            ValueError
                If an axis is out of range for val.
        """
        # Check if the axes provided are valid given the shape
        if axes is not None and \
                not all(axis in range(len(val.shape)) for axis in axes):
            raise ValueError("ERROR: Provided axes does not match array shape")

        # GFFT doesn't accept d2o objects as input. Consolidate data from
        # all nodes into numpy.ndarray before proceeding.
        if isinstance(val, distributed_data_object):
            temp_inp = val.get_full_data()
        else:
            temp_inp = val

        # gfft works on complex input; the result is cast to the codomain's
        # dtype at the end.
        temp_inp = temp_inp.astype(np.complex128, copy=False)

        # Loop-invariant gfft arguments. The zero-center flags are built as
        # lists (a bare map() would be a one-shot iterator on Python 3).
        ftmachine = 'fft' if codomain.harmonic else 'ifft'
        in_zero_center = [bool(z) for z in domain.paradict['zerocenter']]
        out_zero_center = [bool(z) for z in codomain.paradict['zerocenter']]
        enforce_hermitian_symmetry = bool(codomain.paradict['complexity'])

        # Array for storing the result
        return_val = None

        full_slice = [slice(None, None)]
        for slice_list in utilities.get_slice_list(temp_inp.shape, axes):
            if slice_list == full_slice:
                # don't copy the whole data array
                inp = temp_inp
            else:
                # initialize the return_val object if needed
                if return_val is None:
                    return_val = np.empty_like(temp_inp)
                # numpy needs a tuple (not a list) for multi-dim indexing
                inp = temp_inp[tuple(slice_list)]

            inp = self.fft_machine.gfft(
                inp,
                in_ax=[],
                out_ax=[],
                ftmachine=ftmachine,
                in_zero_center=in_zero_center,
                out_zero_center=out_zero_center,
                enforce_hermitian_symmetry=enforce_hermitian_symmetry,
                W=-1,
                alpha=-1,
                verbose=False
            )

            if slice_list == full_slice:
                return_val = inp
            else:
                return_val[tuple(slice_list)] = inp

        if isinstance(val, distributed_data_object):
            new_val = val.copy_empty(dtype=codomain.dtype)
            new_val.set_full_data(return_val, copy=False)
            # If the values living in domain are purely real, the result of
            # the fft is hermitian
            if domain.paradict['complexity'] == 0:
                new_val.hermitian = True
            return_val = new_val
        else:
            return_val = return_val.astype(codomain.dtype, copy=False)

        return return_val