nifty_fft.py 24.6 KB
Newer Older
1
2
# -*- coding: utf-8 -*-

3
4
import warnings

5
import numpy as np
6
7
8
9
from d2o import distributed_data_object,\
                STRATEGIES
from nifty.config import about,\
                         dependency_injector as gdi
10
import nifty.nifty_utilities as utilities
# Optional FFT back-ends, fetched once through the dependency injector.
# A name is None when its package is unavailable (FFTWTransformInfo and
# GFFT test for exactly that before using the module).
pyfftw, gfft, gfft_dummy = (gdi.get(module_name) for module_name in
                            ('pyfftw', 'gfft', 'gfft_dummy'))

def fft_factory(fft_module_name):
    """
        A factory for fast-fourier-transformation objects.

        Parameters
        ----------
        fft_module_name : String
            Select an FFT module: 'pyfftw', 'gfft' or 'gfft_dummy'.

        Returns
        -------
        fft : FFT
            An FFTW object for 'pyfftw', a GFFT object for 'gfft' and
            'gfft_dummy'.

        Raises
        ------
        ValueError
            If fft_module_name is none of the supported module names.
    """
    if fft_module_name == 'pyfftw':
        return FFTW()
    # NOTE: a membership test is required here; the former
    # `name == 'gfft' or 'gfft_dummy'` was always true, so unknown names
    # never reached the ValueError below.
    elif fft_module_name in ('gfft', 'gfft_dummy'):
        return GFFT(fft_module_name)
    else:
        raise ValueError('Given fft_module_name not known: ' +
                         str(fft_module_name))

class FFT(object):
    """
        Abstract interface of a fast-fourier-transformation object.

        It carries no implementation; concrete back-ends subclass it and
        override ``transform``.
    """

    def __init__(self):
        """Nothing to initialise for the abstract base class."""

    def transform(self, val, domain, codomain, axes, **kwargs):
        """
            Generic fourier-transform entry point.

            Parameters
            ----------
            val : distributed_data_object
                The value-array of the field which is supposed to
                be transformed.

            domain : nifty.rg.nifty_rg.rg_space
                The domain of the space which should be transformed.

            codomain : nifty.rg.nifty_rg.rg_space
                The target into which the field should be transformed.

            axes : tuple, None
                The axes which should be transformed.

            Raises
            ------
            NotImplementedError
                Always; subclasses have to provide the transformation.
        """
        raise NotImplementedError()

class FFTW(FFT):
    """
        The pyfftw pendant of a fft object.

        Transform set-ups (centering masks, sign corrections and, for MPI,
        pyfftw plans) are cached in ``info_dict`` keyed by the domain,
        codomain and local data layout, so repeated transforms of the same
        configuration reuse them.
    """

    def __init__(self):
        if 'pyfftw' not in gdi:
            raise ImportError("The module pyfftw is needed but not available.")

        self.name = 'pyfftw'
        # info_dict stores the FFTWTransformInfo objects which correspond
        # to a certain (domain, codomain, local layout) setting.
        self.info_dict = {}

        # initialize the dictionary which stores the values from
        # get_centering_mask
        self.centering_mask_dict = {}

        # Enable caching for pyfftw.interfaces
        pyfftw.interfaces.cache.enable()

    def get_centering_mask(self, to_center_input, dimensions_input,
                           offset_input=False):
        """
            Computes the mask, used to (de-)zerocenter domain and target
            fields.

            Parameters
            ----------
            to_center_input : tuple, list, numpy.ndarray
                A tuple of booleans which dimensions should be
                zero-centered.

            dimensions_input : tuple, list, numpy.ndarray
                A tuple containing the mask's desired shape.

            offset_input : int, boolean
                Specifies whether the zero-th dimension starts with an odd
                or an even index, i.e. if it is shifted.

            Returns
            -------
            result : np.ndarray or int
                A 1/-1-alternating int8 mask of the requested shape. The
                scalar 1 is returned if no dimension is zero-centered, and
                the (all-ones) dimensions array if every dimension has
                length one.

            Raises
            ------
            TypeError
                If to_center_input and dimensions_input differ in length.
        """
        # cast input
        to_center = np.array(to_center_input)
        dimensions = np.array(dimensions_input)

        # if none of the dimensions are zero centered, return a 1
        if np.all(to_center == 0):
            return 1

        if np.all(dimensions == np.array(1)) or \
                np.all(dimensions == np.array([1])):
            return dimensions

        # Check for a length mismatch *before* the size-one filtering below:
        # after filtering both arrays always have equal length, so the check
        # could never fire there (a short to_center raised IndexError).
        if to_center.size != dimensions.size:
            raise TypeError(
                'The length of the supplied lists does not match.')

        # The dimensions of size 1 must be sorted out for computing the
        # centering_mask. The depth of the array will be restored in the
        # end.
        is_size_one = (dimensions == 1)
        dimensions = dimensions[~is_size_one]
        to_center = to_center[~is_size_one]

        # cast the offset_input into the shape of to_center
        offset = np.zeros(to_center.shape, dtype=int)
        offset[0] = int(offset_input)

        # build up the value memory
        # compute an identifier for the parameter set
        temp_id = tuple(
            (tuple(to_center), tuple(dimensions), tuple(offset)))
        if temp_id not in self.centering_mask_dict:
            # use np.tile in order to stack the core alternation scheme
            # until the desired format is constructed.
            core = np.fromfunction(
                lambda *args: (-1) **
                (np.tensordot(to_center,
                              args +
                              offset.reshape(offset.shape +
                                             (1,) *
                                             (np.array(args).ndim - 1)),
                              1)),
                (2,) * to_center.size)
            # Cast the core to the smallest integers we can get
            core = core.astype(np.int8)

            centering_mask = np.tile(core, dimensions // 2)
            # for the dimensions of odd size corresponding slices must be
            # added
            for i in range(centering_mask.ndim):
                # check if the size of the certain dimension is odd or even
                if (dimensions % 2)[i] == 0:
                    continue
                # prepare the slice object
                temp_slice = (slice(None),) * i + (slice(-2, -1, 1),) + \
                             (slice(None),) * (centering_mask.ndim - 1 - i)
                # append the slice to the centering_mask
                centering_mask = np.append(centering_mask,
                                           centering_mask[temp_slice],
                                           axis=i)
            # Add depth to the centering_mask where the length of a
            # dimension was one
            temp_slice = tuple(None if one else slice(None)
                               for one in is_size_one)
            centering_mask = centering_mask[temp_slice]
            self.centering_mask_dict[temp_id] = centering_mask
        return self.centering_mask_dict[temp_id]

    def _get_transform_info(self, domain, codomain, local_shape,
                            local_offset_Q, is_local, **kwargs):
        """
            Return the (cached) transform info for the given setting.

            The cache key includes the local layout and the local/MPI flag:
            the centering masks depend on local_shape and local_offset_Q,
            and a local info object has no MPI plan, so neither may be
            reused for a different layout.
        """
        temp_id = (hash(domain), hash(codomain), tuple(local_shape),
                   bool(local_offset_Q), bool(is_local))

        # generate the plan_and_info object if not already there
        if temp_id not in self.info_dict:
            if is_local:
                self.info_dict[temp_id] = FFTWLocalTransformInfo(
                    domain, codomain, local_shape, local_offset_Q,
                    self, **kwargs
                )
            else:
                self.info_dict[temp_id] = FFTWMPITransfromInfo(
                    domain, codomain, local_shape, local_offset_Q,
                    self, **kwargs
                )

        return self.info_dict[temp_id]

    def _apply_mask(self, val, mask, axes):
        """
            Apply centering mask to an array.

            Parameters
            ----------
            val: distributed_data_object or numpy.ndarray
                The value-array on which the mask should be applied.

            mask: numpy.ndarray
                The mask to be applied.

            axes: tuple
                The axes which are to be transformed.

            Returns
            -------
            distributed_data_object or numpy.ndarray
                A new array: the input multiplied by the mask.
        """
        # reshape mask if necessary
        if axes:
            mask = mask.reshape(
                [y if x in axes else 1
                    for x, y in enumerate(val.shape)]
            )

        return val * mask

    def _atomic_mpi_transform(self, val, info, axes, domain, codomain):
        """
            Run the cached MPI plan on one (local) piece of data.

            Returns the transformed, masked and sign-corrected local array,
            or None if this rank's plan has no output.
        """
        # Apply codomain centering mask (_apply_mask returns a new array,
        # so the caller's data is not modified).
        if any(codomain.paradict['zerocenter']):
            val = self._apply_mask(val, info.cmask_codomain, axes)

        p = info.plan
        # Load the value into the plan
        if p.has_input:
            p.input_array[:] = val
        # Execute the plan
        p()

        if not p.has_output:
            return None

        # output_array is the cached plan's internal buffer, which the next
        # execution overwrites: never hand it out directly.
        if any(domain.paradict['zerocenter']):
            # Apply domain centering mask (yields a fresh array)
            result = self._apply_mask(p.output_array, info.cmask_domain,
                                      axes)
        else:
            result = p.output_array.copy()

        # Correct the sign if needed
        result *= info.sign

        return result

    def _local_transform(self, val, domain, codomain, axes, **kwargs):
        """
            Transform val node-locally with pyfftw.interfaces.

            val must be a numpy array or a d2o with a slicing distributor.
        """
        local_offset_Q = False
        try:
            # NOTE: no trailing comma here -- it would turn local_val into
            # a tuple.
            local_val = val.get_local_data(copy=False)
            if axes is None or 0 in axes:
                local_offset_Q = val.distributor.local_shape[0] % 2
        except AttributeError:
            local_val = val

        current_info = self._get_transform_info(domain,
                                                codomain,
                                                local_shape=local_val.shape,
                                                local_offset_Q=local_offset_Q,
                                                is_local=True,
                                                **kwargs)

        # Apply codomain centering mask (returns a new array)
        if any(codomain.paradict['zerocenter']):
            local_val = self._apply_mask(local_val,
                                         current_info.cmask_codomain, axes)

        local_result = current_info.fftw_interface(
                                            local_val,
                                            axes=axes,
                                            planner_effort='FFTW_ESTIMATE')

        # Apply domain centering mask
        if any(domain.paradict['zerocenter']):
            local_result = self._apply_mask(local_result,
                                            current_info.cmask_domain, axes)

        # Correct the sign if needed
        if current_info.sign != 1:
            local_result *= current_info.sign

        try:
            # Create return object and insert results inplace
            return_val = val.copy_empty(global_shape=val.shape,
                                        dtype=codomain.dtype)
            return_val.set_local_data(data=local_result, copy=False)
        except AttributeError:
            return_val = local_result

        return return_val

    def _repack_to_fftw_and_transform(self, val, domain, codomain,
                                      axes, **kwargs):
        """
            Redistribute val to the 'fftw' strategy, transform it, and
            return the result in val's original distribution strategy.
        """
        temp_val = val.copy_empty(distribution_strategy='fftw')
        about.warnings.cprint('WARNING: Repacking d2o to fftw '
                              'distribution strategy')
        temp_val.set_full_data(val, copy=False)

        # Recursive call to transform
        result = self.transform(temp_val, domain, codomain, axes, **kwargs)

        return_val = result.copy_empty(
            distribution_strategy=val.distribution_strategy)
        return_val.set_full_data(data=result, copy=False)

        return return_val

    def _mpi_transform(self, val, domain, codomain, axes, **kwargs):
        """
            Transform a d2o distributed with the 'fftw' strategy using an
            MPI plan, slice by slice over the non-transformed axes.
        """
        if axes is None or 0 in axes:
            local_offset_list = np.cumsum(np.concatenate(
                                    [[0, ],
                                     val.distributor.all_local_slices[:, 2]]))
            local_offset_Q = bool(
                local_offset_list[val.distributor.comm.rank] % 2)
        else:
            local_offset_Q = False
        current_info = self._get_transform_info(domain,
                                                codomain,
                                                local_shape=val.local_shape,
                                                local_offset_Q=local_offset_Q,
                                                is_local=False,
                                                **kwargs)
        return_val = val.copy_empty(global_shape=val.shape,
                                    dtype=codomain.dtype)

        # Extract local data
        local_val = val.get_local_data(copy=False)

        # Create temporary storage for slices
        temp_val = None

        # If axes tuple includes all axes, set it to None
        if axes is not None:
            if set(axes) == set(range(len(val.shape))):
                axes = None

        for slice_list in utilities.get_slice_list(local_val.shape, axes):
            is_full = list(slice_list) == [slice(None, None)]
            # numpy requires a tuple (not a list) for multi-axis indexing
            index = tuple(slice_list)
            if is_full:
                inp = local_val
            else:
                if temp_val is None:
                    temp_val = np.empty_like(local_val)
                inp = local_val[index]

            # This is in order to make FFTW behave properly when slicing input
            # over MPI ranks when the input is 1-dimensional. The default
            # behaviour is to optimize to take advantage of byte-alignment,
            # which doesn't match the slicing strategy for multi-dimensional
            # data.
            original_shape = None
            if len(inp.shape) == 1:
                original_shape = inp.shape
                inp = inp.reshape(inp.shape[0], 1)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                result = self._atomic_mpi_transform(inp, current_info, axes,
                                                    domain, codomain)

            if result is None:
                temp_val = np.empty_like(local_val)
                continue

            # Revert to the original shape, i.e. before the input was
            # augmented with 1 -- in the full-slice case as well.
            if original_shape is not None:
                result = result.reshape(original_shape)

            if is_full:
                temp_val = result
            else:
                temp_val[index] = result

        return_val.set_local_data(data=temp_val, copy=False)

        return return_val

    def transform(self, val, domain, codomain, axes=None, **kwargs):
        """
            The pyfftw transform function.

            Parameters
            ----------
            val : distributed_data_object or numpy.ndarray
                The value-array of the field which is supposed to
                be transformed.

            domain : nifty.rg.nifty_rg.rg_space
                The domain of the space which should be transformed.

            codomain : nifty.rg.nifty_rg.rg_space
                The target into which the field should be transformed.

            axes: tuple, None
                The axes which should be transformed.

            **kwargs : *optional*
                Further kwargs are passed to the create_mpi_plan routine.

            Returns
            -------
            result : np.ndarray or distributed_data_object
                Fourier-transformed pendant of the input field.

            Raises
            ------
            ValueError
                If an entry of axes is not an axis of val.
        """
        # Check if the axes provided are valid given the shape
        if axes is not None and \
                not all(axis in range(len(val.shape)) for axis in axes):
            raise ValueError("ERROR: Provided axes does not match array shape")

        # If the input is a numpy array we transform it locally
        if not isinstance(val, distributed_data_object):
            # Cast to a np.ndarray. _local_transform builds/fetches the
            # matching transform info itself and does not work inplace.
            temp_val = np.asarray(val)
            return_val = self._local_transform(temp_val, domain, codomain,
                                               axes, **kwargs)
        else:
            if val.distribution_strategy in STRATEGIES['slicing']:
                if axes is None or 0 in axes:
                    if val.distribution_strategy != 'fftw':
                        return_val = \
                            self._repack_to_fftw_and_transform(
                                val, domain, codomain, axes, **kwargs
                            )
                    else:
                        return_val = self._mpi_transform(
                            val, domain, codomain, axes, **kwargs
                        )
                else:
                    # axis 0 (the distributed one) is untouched, so every
                    # node can transform its own part independently
                    return_val = self._local_transform(
                        val, domain, codomain, axes, **kwargs
                    )
            else:
                return_val = self._repack_to_fftw_and_transform(
                    val, domain, codomain, axes, **kwargs
                )

            # If domain is purely real, the result of the FFT is hermitian
            if domain.paradict['complexity'] == 0:
                return_val.hermitian = True

        return return_val

class FFTWTransformInfo(object):
    """
        Common data of a pyfftw transform set-up: the centering masks for
        domain and codomain and the global sign correction.
    """

    def __init__(self, domain, codomain, local_shape, local_offset_Q,
                 fftw_context, **kwargs):
        if pyfftw is None:
            raise ImportError("The module pyfftw is needed but not available.")

        centering_mask = fftw_context.get_centering_mask
        self.cmask_domain = centering_mask(domain.paradict['zerocenter'],
                                           local_shape,
                                           local_offset_Q)
        self.cmask_codomain = centering_mask(codomain.paradict['zerocenter'],
                                             local_shape,
                                             local_offset_Q)

        # Axes zero-centered in both domain and codomain whose half-length
        # is odd each contribute a factor -1; store the total to undo it.
        both_centered = (np.array(domain.paradict['zerocenter']) *
                         np.array(codomain.paradict['zerocenter']))
        half_length_parity = np.array(domain.get_shape()) // 2 % 2
        self.sign = (-1) ** np.sum(both_centered * half_length_parity)

    def _stored(attr_name):
        # Build a plain read/write property backed by ``attr_name``.
        def getter(self):
            return getattr(self, attr_name)

        def setter(self, value):
            setattr(self, attr_name, value)

        return property(getter, setter)

    cmask_domain = _stored('_domain_centering_mask')
    cmask_codomain = _stored('_codomain_centering_mask')
    sign = _stored('_sign')

    del _stored

class FFTWLocalTransformInfo(FFTWTransformInfo):
    """
        Transform set-up for node-local transforms carried out through the
        pyfftw.interfaces.numpy_fft functions.
    """

    def __init__(self, domain, codomain, local_shape, local_offset_Q,
                 fftw_context, **kwargs):
        super(FFTWLocalTransformInfo, self).__init__(domain,
                                                     codomain,
                                                     local_shape,
                                                     local_offset_Q,
                                                     fftw_context,
                                                     **kwargs)
        # A harmonic codomain means a forward transform, otherwise the
        # inverse one. Both functions live in pyfftw.interfaces.numpy_fft
        # (the former `numpy_fftn` module name does not exist).
        if codomain.harmonic:
            self._fftw_interface = pyfftw.interfaces.numpy_fft.fftn
        else:
            self._fftw_interface = pyfftw.interfaces.numpy_fft.ifftn

    @property
    def fftw_interface(self):
        return self._fftw_interface

    @fftw_interface.setter
    def fftw_interface(self, interface):
        # The interface is fixed by the codomain; assignments are ignored
        # with a warning.
        about.warnings.cprint('WARNING: FFTWLocalTransformInfo '
                              'fftw_interface cannot be modified')


class FFTWMPITransfromInfo(FFTWTransformInfo):
    """
        Transform set-up for MPI-distributed transforms: holds a pyfftw MPI
        plan built for the global domain shape.
    """

    def __init__(self, domain, codomain, local_shape, local_offset_Q,
                 fftw_context, **kwargs):
        super(FFTWMPITransfromInfo, self).__init__(domain,
                                                   codomain,
                                                   local_shape,
                                                   local_offset_Q,
                                                   fftw_context,
                                                   **kwargs)
        # When the domain is 1-dimensional, reshape it so that it can
        # accept input which is also augmented by 1.
        if len(domain.get_shape()) == 1:
            shape = (domain.get_shape()[0], 1)
        else:
            shape = domain.get_shape()

        # Forward transform for a harmonic codomain, backward otherwise;
        # further kwargs are forwarded to pyfftw.create_mpi_plan.
        self._plan = pyfftw.create_mpi_plan(
            input_shape=shape,
            input_dtype='complex128',
            output_dtype='complex128',
            direction='FFTW_FORWARD' if codomain.harmonic else 'FFTW_BACKWARD',
            flags=["FFTW_ESTIMATE"],
            **kwargs
        )

    @property
    def plan(self):
        return self._plan

    @plan.setter
    def plan(self, plan):
        # The plan is fixed at construction; assignments are ignored with
        # a warning.
        about.warnings.cprint('WARNING: FFTWMPITransfromInfo plan '
                              'cannot be modified')


class GFFT(FFT):
    """
        The gfft pendant of a fft object.

        Parameters
        ----------
        fft_module_name : String
            Switch between the gfft module used: 'gfft' and 'gfft_dummy'

        Raises
        ------
        ImportError
            If the selected module is not available.
    """

    def __init__(self, fft_module_name):
        self.name = fft_module_name
        self.fft_machine = gdi.get(fft_module_name)
        if self.fft_machine is None:
            raise ImportError(
                "The gfft(_dummy)-module is needed but not available.")

    def transform(self, val, domain, codomain, axes=None, **kwargs):
        """
            The gfft transform function.

            Parameters
            ----------
            val : numpy.ndarray or distributed_data_object
                The value-array of the field which is supposed to
                be transformed.

            domain : nifty.rg.nifty_rg.rg_space
                The domain of the space which should be transformed.

            codomain : nifty.rg.nifty_rg.rg_space
                The target into which the field should be transformed.

            axes : None or tuple
                The axes which should be transformed.

            **kwargs : *optional*
                Further kwargs are not processed.

            Returns
            -------
            result : np.ndarray or distributed_data_object
                Fourier-transformed pendant of the input field.

            Raises
            ------
            ValueError
                If an entry of axes is not an axis of val.
        """
        # Check if the axes provided are valid given the shape
        if axes is not None and \
                not all(axis in range(len(val.shape)) for axis in axes):
            raise ValueError("ERROR: Provided axes does not match array shape")

        # GFFT doesn't accept d2o objects as input. Consolidate data from
        # all nodes into numpy.ndarray before proceeding.
        if isinstance(val, distributed_data_object):
            temp_inp = val.get_full_data()
        else:
            temp_inp = np.asarray(val)

        # gfft works on complex data: cast the input to complex128
        # (the result is cast to codomain.dtype at the end).
        temp_inp = temp_inp.astype(np.complex128, copy=False)

        # Loop-invariant gfft arguments. Plain lists (not `map` objects,
        # which are one-shot iterators on Python 3).
        ftmachine = 'fft' if codomain.harmonic else 'ifft'
        in_zero_center = [bool(x) for x in domain.paradict['zerocenter']]
        out_zero_center = [bool(x) for x in codomain.paradict['zerocenter']]
        enforce_hermitian = bool(codomain.paradict['complexity'])

        # Array for storing the result
        return_val = None

        for slice_list in utilities.get_slice_list(temp_inp.shape, axes):
            is_full = list(slice_list) == [slice(None, None)]
            # numpy requires a tuple (not a list) for multi-axis indexing
            index = tuple(slice_list)
            # don't copy the whole data array
            if is_full:
                inp = temp_inp
            else:
                # initialize the return_val object if needed
                if return_val is None:
                    return_val = np.empty_like(temp_inp)
                inp = temp_inp[index]

            inp = self.fft_machine.gfft(
                inp,
                in_ax=[],
                out_ax=[],
                ftmachine=ftmachine,
                in_zero_center=in_zero_center,
                out_zero_center=out_zero_center,
                enforce_hermitian_symmetry=enforce_hermitian,
                W=-1,
                alpha=-1,
                verbose=False
            )
            if is_full:
                return_val = inp
            else:
                return_val[index] = inp

        if isinstance(val, distributed_data_object):
            new_val = val.copy_empty(dtype=codomain.dtype)
            new_val.set_full_data(return_val, copy=False)
            # If the values living in domain are purely real, the result of
            # the fft is hermitian
            if domain.paradict['complexity'] == 0:
                new_val.hermitian = True
            return_val = new_val
        else:
            return_val = return_val.astype(codomain.dtype, copy=False)

        return return_val