From 9d4cead3a41cad9fdb8e4563bbfa39fee4bf2a68 Mon Sep 17 00:00:00 2001
From: Theo Steininger <theos@mpa-garching.mpg.de>
Date: Wed, 19 Apr 2017 12:18:55 +0200
Subject: [PATCH] Removed gfft.

---
 nifty/config/d2o_config.py                    |   6 +-
 nifty/config/nifty_config.py                  |  14 +-
 .../transformations/rg_transforms.py          | 296 +++++++++---------
 .../transformations/rgrgtransformation.py     |  23 +-
 setup.py                                      |   6 +-
 5 files changed, 168 insertions(+), 177 deletions(-)

diff --git a/nifty/config/d2o_config.py b/nifty/config/d2o_config.py
index 9e20a2669..414eb7b5f 100644
--- a/nifty/config/d2o_config.py
+++ b/nifty/config/d2o_config.py
@@ -24,6 +24,6 @@ import keepers
 d2o_configuration = keepers.get_Configuration(
                     name='D2O',
                     file_name='D2O.conf',
-                    search_pathes=[os.path.expanduser('~') + "/.config/nifty/",
-                                   os.path.expanduser('~') + "/.config/",
-                                   './'])
+                    search_paths=[os.path.expanduser('~') + "/.config/nifty/",
+                                  os.path.expanduser('~') + "/.config/",
+                                  './'])
diff --git a/nifty/config/nifty_config.py b/nifty/config/nifty_config.py
index a8fff831f..4211e8bc0 100644
--- a/nifty/config/nifty_config.py
+++ b/nifty/config/nifty_config.py
@@ -25,8 +25,6 @@ import keepers
 # Setup the dependency injector
 dependency_injector = keepers.DependencyInjector(
                                    [('mpi4py.MPI', 'MPI'),
-                                    'gfft',
-                                    ('nifty.dummys.gfft_dummy', 'gfft_dummy'),
                                     'healpy',
                                     'libsharp_wrapper_gl'])
 
@@ -36,8 +34,9 @@ dependency_injector.register('pyfftw', lambda z: hasattr(z, 'FFTW_MPI'))
 # Initialize the variables
 variable_fft_module = keepers.Variable(
                                'fft_module',
-                               ['pyfftw', 'gfft', 'gfft_dummy'],
-                               lambda z: z in dependency_injector)
+                               ['fftw', 'numpy'],
+                               lambda z: (('pyfftw' in dependency_injector)
+                                          if z == 'fftw' else True))
 
 
 def _healpy_validator(use_healpy):
@@ -97,12 +96,11 @@ nifty_configuration = keepers.get_Configuration(
                             variable_default_field_dtype,
                             variable_default_distribution_strategy],
                  file_name='NIFTy.conf',
-                 search_pathes=[os.path.expanduser('~') + "/.config/nifty/",
-                                os.path.expanduser('~') + "/.config/",
-                                './'])
+                 search_paths=[os.path.expanduser('~') + "/.config/nifty/",
+                               os.path.expanduser('~') + "/.config/",
+                               './'])
 
 ########
-########
 
 try:
     nifty_configuration.load()
diff --git a/nifty/operators/fft_operator/transformations/rg_transforms.py b/nifty/operators/fft_operator/transformations/rg_transforms.py
index d12ade8cc..ba928f757 100644
--- a/nifty/operators/fft_operator/transformations/rg_transforms.py
+++ b/nifty/operators/fft_operator/transformations/rg_transforms.py
@@ -33,47 +33,10 @@ class Transform(Loggable, object):
         A generic fft object without any implementation.
     """
 
-    def __init__(self, domain, codomain):
-        pass
-
-    def transform(self, val, axes, **kwargs):
-        """
-            A generic ff-transform function.
-
-            Parameters
-            ----------
-            field_val : distributed_data_object
-                The value-array of the field which is supposed to
-                be transformed.
-
-            domain : nifty.rg.nifty_rg.rg_space
-                The domain of the space which should be transformed.
-
-            codomain : nifty.rg.nifty_rg.rg_space
-                The taget into which the field should be transformed.
-        """
-        raise NotImplementedError
-
-
-class FFTW(Transform):
-    """
-        The pyfftw pendant of a fft object.
-    """
-
     def __init__(self, domain, codomain):
         self.domain = domain
         self.codomain = codomain
 
-        if 'pyfftw' not in gdi:
-            raise ImportError("The module pyfftw is needed but not available.")
-
-        # Enable caching for pyfftw.interfaces
-        pyfftw.interfaces.cache.enable()
-
-        # The plan_dict stores the FFTWTransformInfo objects which correspond
-        # to a certain set of (field_val, domain, codomain) sets.
-        self.info_dict = {}
-
         # initialize the dictionary which stores the values from
         # get_centering_mask
         self.centering_mask_dict = {}
@@ -130,7 +93,10 @@ class FFTW(Transform):
         to_center = np.array(temp_to_center)
         # cast the offset_input into the shape of to_center
         offset = np.zeros(to_center.shape, dtype=int)
-        offset[0] = int(offset_input)
+        # if the first dimension has length 1 and has an offset, restore the
+        # global minus by hand
+        if not size_one_dimensions[0]:
+            offset[0] = int(offset_input)
         # check for dimension match
         if to_center.size != dimensions.size:
             raise TypeError(
@@ -179,34 +145,14 @@ class FFTW(Transform):
                 else:
                     temp_slice += (slice(None),)
             centering_mask = centering_mask[temp_slice]
+            # if the first dimension has length 1 and has an offset, restore
+            # the global minus by hand
+            if size_one_dimensions[0] and offset_input:
+                centering_mask *= -1
+
             self.centering_mask_dict[temp_id] = centering_mask
         return self.centering_mask_dict[temp_id]
 
-    def _get_transform_info(self, domain, codomain, axes, local_shape,
-                            local_offset_Q, is_local, transform_shape=None,
-                            **kwargs):
-        # generate a id-tuple which identifies the domain-codomain setting
-        temp_id = (domain.__hash__() ^
-                   (101 * codomain.__hash__()) ^
-                   (211 * transform_shape.__hash__()) ^
-                   (131 * is_local.__hash__())
-                   )
-
-        # generate the plan_and_info object if not already there
-        if temp_id not in self.info_dict:
-            if is_local:
-                self.info_dict[temp_id] = FFTWLocalTransformInfo(
-                    domain, codomain, axes, local_shape,
-                    local_offset_Q, self, **kwargs
-                )
-            else:
-                self.info_dict[temp_id] = FFTWMPITransfromInfo(
-                    domain, codomain, axes, local_shape,
-                    local_offset_Q, self, transform_shape, **kwargs
-                )
-
-        return self.info_dict[temp_id]
-
     def _apply_mask(self, val, mask, axes):
         """
             Apply centering mask to an array.
@@ -233,9 +179,71 @@ class FFTW(Transform):
                 [y if x in axes else 1
                  for x, y in enumerate(val.shape)]
             )
-
         return val * mask
 
+    def transform(self, val, axes, **kwargs):
+        """
+            A generic ff-transform function.
+
+            Parameters
+            ----------
+            field_val : distributed_data_object
+                The value-array of the field which is supposed to
+                be transformed.
+
+            domain : nifty.rg.nifty_rg.rg_space
+                The domain of the space which should be transformed.
+
+            codomain : nifty.rg.nifty_rg.rg_space
+                The taget into which the field should be transformed.
+        """
+        raise NotImplementedError
+
+
+class FFTW(Transform):
+    """
+        The pyfftw pendant of a fft object.
+    """
+
+    def __init__(self, domain, codomain):
+
+        if 'pyfftw' not in gdi:
+            raise ImportError("The module pyfftw is needed but not available.")
+
+        super(FFTW, self).__init__(domain, codomain)
+
+        # Enable caching for pyfftw.interfaces
+        pyfftw.interfaces.cache.enable()
+
+        # The plan_dict stores the FFTWTransformInfo objects which correspond
+        # to a certain set of (field_val, domain, codomain) sets.
+        self.info_dict = {}
+
+    def _get_transform_info(self, domain, codomain, axes, local_shape,
+                            local_offset_Q, is_local, transform_shape=None,
+                            **kwargs):
+        # generate a id-tuple which identifies the domain-codomain setting
+        temp_id = (domain.__hash__() ^
+                   (101 * codomain.__hash__()) ^
+                   (211 * transform_shape.__hash__()) ^
+                   (131 * is_local.__hash__())
+                   )
+
+        # generate the plan_and_info object if not already there
+        if temp_id not in self.info_dict:
+            if is_local:
+                self.info_dict[temp_id] = FFTWLocalTransformInfo(
+                    domain, codomain, axes, local_shape,
+                    local_offset_Q, self, **kwargs
+                )
+            else:
+                self.info_dict[temp_id] = FFTWMPITransfromInfo(
+                    domain, codomain, axes, local_shape,
+                    local_offset_Q, self, transform_shape, **kwargs
+                )
+
+        return self.info_dict[temp_id]
+
     def _atomic_mpi_transform(self, val, info, axes):
         # Apply codomain centering mask
         if reduce(lambda x, y: x + y, self.codomain.zerocenter):
@@ -333,7 +341,6 @@ class FFTW(Transform):
             np.concatenate([[0, ], val.distributor.all_local_slices[:, 2]])
         )
         local_offset_Q = bool(local_offset_list[val.distributor.comm.rank] % 2)
-
         result_dtype = np.result_type(np.complex, self.codomain.dtype)
         return_val = val.copy_empty(global_shape=val.shape,
                                     dtype=result_dtype)
@@ -472,45 +479,33 @@ class FFTWTransformInfo(object):
         shape = (local_shape if axes is None else
                  [y for x, y in enumerate(local_shape) if x in axes])
 
-        self.cmask_domain = fftw_context.get_centering_mask(domain.zerocenter,
-                                                            shape,
-                                                            local_offset_Q)
+        self._cmask_domain = fftw_context.get_centering_mask(domain.zerocenter,
+                                                             shape,
+                                                             local_offset_Q)
 
-        self.cmask_codomain = fftw_context.get_centering_mask(
-                                                        codomain.zerocenter,
-                                                        shape,
-                                                        local_offset_Q)
+        self._cmask_codomain = fftw_context.get_centering_mask(
+                                                         codomain.zerocenter,
+                                                         shape,
+                                                         local_offset_Q)
 
         # If both domain and codomain are zero-centered the result,
         # will get a global minus. Store the sign to correct it.
-        self.sign = (-1) ** np.sum(np.array(domain.zerocenter) *
-                                   np.array(codomain.zerocenter) *
-                                   (np.array(domain.shape) // 2 % 2))
+        self._sign = (-1) ** np.sum(np.array(domain.zerocenter) *
+                                    np.array(codomain.zerocenter) *
+                                    (np.array(domain.shape) // 2 % 2))
 
     @property
     def cmask_domain(self):
-        return self._domain_centering_mask
-
-    @cmask_domain.setter
-    def cmask_domain(self, cmask):
-        self._domain_centering_mask = cmask
+        return self._cmask_domain
 
     @property
     def cmask_codomain(self):
-        return self._codomain_centering_mask
-
-    @cmask_codomain.setter
-    def cmask_codomain(self, cmask):
-        self._codomain_centering_mask = cmask
+        return self._cmask_codomain
 
     @property
     def sign(self):
         return self._sign
 
-    @sign.setter
-    def sign(self, sign):
-        self._sign = sign
-
 
 class FFTWLocalTransformInfo(FFTWTransformInfo):
     def __init__(self, domain, codomain, axes, local_shape,
@@ -556,9 +551,9 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
         return self._plan
 
 
-class GFFT(Transform):
+class NUMPYFFT(Transform):
     """
-        The gfft pendant of a fft object.
+        The numpy fft pendant of a fft object.
 
         Parameters
         ----------
@@ -567,33 +562,21 @@ class GFFT(Transform):
 
     """
 
-    def __init__(self, domain, codomain, fft_module=None):
-        if fft_module is None:
-            fft_module = gdi['gfft_dummy']
-
-        # GFFT only works for domains with an even number of pixels per axis
-        if np.any(np.array(domain.shape) % 2):
-            raise AttributeError("GFFT needs an even number of pixels per "
-                                 "axis of domain. Got: %s" % str(domain.shape))
-        self.domain = domain
-        self.codomain = codomain
-        self.fft_machine = fft_module
-
     def transform(self, val, axes, **kwargs):
         """
-            The gfft transform function.
+            The pyfftw transform function.
 
             Parameters
             ----------
-            val : numpy.ndarray or distributed_data_object
+            val : distributed_data_object or numpy.ndarray
                 The value-array of the field which is supposed to
                 be transformed.
 
-            axes : None or tuple
+            axes: tuple, None
                 The axes which should be transformed.
 
             **kwargs : *optional*
-                Further kwargs are not processed.
+                Further kwargs are passed to the create_mpi_plan routine.
 
             Returns
             -------
@@ -605,51 +588,70 @@ class GFFT(Transform):
                 not all(axis in range(len(val.shape)) for axis in axes):
             raise ValueError("Provided axes does not match array shape")
 
-        # GFFT doesn't accept d2o objects as input. Consolidate data from
-        # all nodes into numpy.ndarray before proceeding.
-        if isinstance(val, distributed_data_object):
-            temp_inp = val.get_full_data()
-        else:
-            temp_inp = val
-
-        # Array for storing the result
-        return_val = None
         result_dtype = np.result_type(np.complex, self.codomain.dtype)
+        return_val = val.copy_empty(global_shape=val.shape,
+                                    dtype=result_dtype)
 
-        for slice_list in utilities.get_slice_list(temp_inp.shape, axes):
+        if (axes is None) or (0 in axes) or \
+           (val.distribution_strategy not in STRATEGIES['slicing']):
 
-            # don't copy the whole data array
-            if slice_list == [slice(None, None)]:
-                inp = temp_inp
+            if val.distribution_strategy == 'not':
+                local_val = val.get_local_data(copy=False)
             else:
-                # initialize the return_val object if needed
-                if return_val is None:
-                    return_val = np.empty_like(temp_inp, dtype=result_dtype)
-                inp = temp_inp[slice_list]
-
-            inp = self.fft_machine.gfft(
-                inp,
-                in_ax=[],
-                out_ax=[],
-                ftmachine='fft' if self.codomain.harmonic else 'ifft',
-                in_zero_center=map(bool, self.domain.zerocenter),
-                out_zero_center=map(bool, self.codomain.zerocenter),
-                # enforce_hermitian_symmetry=bool(self.codomain.complexity),
-                enforce_hermitian_symmetry=False,
-                W=-1,
-                alpha=-1,
-                verbose=False
-            )
-            if slice_list == [slice(None, None)]:
-                return_val = inp
-            else:
-                return_val[slice_list] = inp
+                local_val = val.get_full_data()
+
+            result_data = self._atomic_transform(local_val=local_val,
+                                                 axes=axes,
+                                                 local_offset_Q=False)
+            return_val.set_full_data(result_data, copy=False)
 
-        if isinstance(val, distributed_data_object):
-            new_val = val.copy_empty(dtype=result_dtype)
-            new_val.set_full_data(return_val, copy=False)
-            return_val = new_val
         else:
-            return_val = return_val.astype(result_dtype, copy=False)
+            local_offset_list = np.cumsum(
+                    np.concatenate([[0, ],
+                                    val.distributor.all_local_slices[:, 2]]))
+            local_offset_Q = \
+                bool(local_offset_list[val.distributor.comm.rank] % 2)
+
+            local_val = val.get_local_data()
+            result_data = self._atomic_transform(local_val=local_val,
+                                                 axes=axes,
+                                                 local_offset_Q=local_offset_Q)
+
+            return_val.set_local_data(result_data, copy=False)
 
         return return_val
+
+    def _atomic_transform(self, local_val, axes, local_offset_Q):
+
+        # some auxiliaries for the mask computation
+        local_shape = local_val.shape
+        shape = (local_shape if axes is None else
+                 [y for x, y in enumerate(local_shape) if x in axes])
+
+        # Apply codomain centering mask
+        if reduce(lambda x, y: x + y, self.codomain.zerocenter):
+            temp_val = np.copy(local_val)
+            mask = self.get_centering_mask(self.codomain.zerocenter,
+                                           shape,
+                                           local_offset_Q)
+            local_val = self._apply_mask(temp_val, mask, axes)
+
+        # perform the transformation
+        result_val = np.fft.fftn(local_val, axes=axes)
+
+        # Apply domain centering mask
+        if reduce(lambda x, y: x + y, self.domain.zerocenter):
+            mask = self.get_centering_mask(self.domain.zerocenter,
+                                           shape,
+                                           local_offset_Q)
+            result_val = self._apply_mask(result_val, mask, axes)
+
+        # If both domain and codomain are zero-centered the result,
+        # will get a global minus. Store the sign to correct it.
+        sign = (-1) ** np.sum(np.array(self.domain.zerocenter) *
+                              np.array(self.codomain.zerocenter) *
+                              (np.array(self.domain.shape) // 2 % 2))
+        if sign != 1:
+            result_val *= sign
+
+        return result_val
diff --git a/nifty/operators/fft_operator/transformations/rgrgtransformation.py b/nifty/operators/fft_operator/transformations/rgrgtransformation.py
index ea3e9dc2b..eca50fc72 100644
--- a/nifty/operators/fft_operator/transformations/rgrgtransformation.py
+++ b/nifty/operators/fft_operator/transformations/rgrgtransformation.py
@@ -18,8 +18,7 @@
 
 import numpy as np
 from transformation import Transformation
-from rg_transforms import FFTW, GFFT
-from nifty.config import dependency_injector as gdi
+from rg_transforms import FFTW, NUMPYFFT
 from nifty import RGSpace, nifty_configuration
 
 
@@ -29,26 +28,18 @@ class RGRGTransformation(Transformation):
                                                  module=module)
 
         if module is None:
-            if nifty_configuration['fft_module'] == 'pyfftw':
+            if nifty_configuration['fft_module'] == 'fftw':
                 self._transform = FFTW(self.domain, self.codomain)
-            elif (nifty_configuration['fft_module'] == 'gfft' or
-                  nifty_configuration['fft_module'] == 'gfft_dummy'):
-                self._transform = \
-                    GFFT(self.domain,
-                         self.codomain,
-                         gdi.get(nifty_configuration['fft_module']))
+            elif nifty_configuration['fft_module'] == 'numpy':
+                self._transform = NUMPYFFT(self.domain, self.codomain)
             else:
                 raise ValueError('ERROR: unknow default FFT module:' +
                                  nifty_configuration['fft_module'])
         else:
-            if module == 'pyfftw':
+            if module == 'fftw':
                 self._transform = FFTW(self.domain, self.codomain)
-            elif module == 'gfft':
-                self._transform = \
-                    GFFT(self.domain, self.codomain, gdi.get('gfft'))
-            elif module == 'gfft_dummy':
-                self._transform = \
-                    GFFT(self.domain, self.codomain, gdi.get('gfft_dummy'))
+            elif module == 'numpy':
+                self._transform = NUMPYFFT(self.domain, self.codomain)
             else:
                 raise ValueError('ERROR: unknow FFT module:' + module)
 
diff --git a/setup.py b/setup.py
index ed184f31a..4d8c1a9c8 100644
--- a/setup.py
+++ b/setup.py
@@ -42,9 +42,9 @@ setup(name="ift_nifty",
       ]),
       include_dirs=[numpy.get_include()],
       dependency_links=[
-        'git+https://gitlab.mpcdf.mpg.de/ift/keepers.git#egg=keepers-0.3.5',
-        'git+https://gitlab.mpcdf.mpg.de/ift/d2o.git#egg=d2o-1.0.7'],
-      install_requires=['keepers>=0.3.5', 'd2o>=1.0.7'],
+        'git+https://gitlab.mpcdf.mpg.de/ift/keepers.git#egg=keepers-0.3.7',
+        'git+https://gitlab.mpcdf.mpg.de/ift/d2o.git#egg=d2o-1.0.8'],
+      install_requires=['keepers>=0.3.7', 'd2o>=1.0.8'],
       package_data={'nifty.demos': ['demo_faraday_map.npy'],
                     },
       license="GPLv3",
-- 
GitLab