From 2057a238d9125e3a3ce0179da3bdfefbd1d25497 Mon Sep 17 00:00:00 2001
From: Jait Dixit <jait.dixit@tum.de>
Date: Tue, 12 Jul 2016 20:43:32 +0200
Subject: [PATCH] WIP: More refactoring

- rename transforms submodule to transformations
- Add a Transformation class which wraps the actual transform class
- Derive and implement RGRGTransformation
- Add tests
---
 __init__.py                                   |  2 +-
 setup.py                                      |  2 +-
 test/test_nifty_transforms.py                 | 63 ++++++++++++++++---
 transformations/__init__.py                   |  3 +
 {transforms => transformations}/fftw.py       | 15 +----
 {transforms => transformations}/gfft.py       | 11 ++--
 transformations/transform.py                  | 32 ++++++++++
 .../transformation.py                         | 62 ++++++++++++------
 transformations/transformation_factory.py     | 29 +++++++++
 transforms/__init__.py                        |  3 -
 transforms/transform_factory.py               | 53 ----------------
 11 files changed, 171 insertions(+), 104 deletions(-)
 create mode 100644 transformations/__init__.py
 rename {transforms => transformations}/fftw.py (97%)
 rename {transforms => transformations}/gfft.py (92%)
 create mode 100644 transformations/transform.py
 rename transforms/transform.py => transformations/transformation.py (54%)
 create mode 100644 transformations/transformation_factory.py
 delete mode 100644 transforms/__init__.py
 delete mode 100644 transforms/transform_factory.py

diff --git a/__init__.py b/__init__.py
index a724d7eee..ce246734e 100644
--- a/__init__.py
+++ b/__init__.py
@@ -93,4 +93,4 @@ from demos import get_demo_dir
 from pickling import _pickle_method, _unpickle_method
 
 #import pyximport; pyximport.install(pyimport = True)
-from transforms import tf as transformator
+from transformations import tf as transformator
diff --git a/setup.py b/setup.py
index cef113359..7e933ad73 100644
--- a/setup.py
+++ b/setup.py
@@ -34,7 +34,7 @@ setup(name="ift_nifty",
       url="http://www.mpa-garching.mpg.de/ift/nifty/",
       packages=["nifty", "nifty.demos", "nifty.rg", "nifty.lm",
                 "nifty.operators", "nifty.dummys", "nifty.field_types",
-                "nifty.config", "nifty.power", "nifty.transforms"],
+                "nifty.config", "nifty.power", "nifty.transformations"],
       package_dir={"nifty": ""},
       zip_safe=False,
       dependency_links=[
diff --git a/test/test_nifty_transforms.py b/test/test_nifty_transforms.py
index 295199d57..400761a4d 100644
--- a/test/test_nifty_transforms.py
+++ b/test/test_nifty_transforms.py
@@ -7,7 +7,7 @@ import itertools
 
 from nifty import RGSpace, LMSpace, HPSpace, GLSpace
 from nifty import transformator
-from nifty.transforms.transform import Transform
+from nifty.transformations.transformation import Transformation
 from nifty.rg.rg_space import gc as RG_GC
 import d2o
 
@@ -29,6 +29,28 @@ for name in ['gfft', 'gfft_dummy', 'pyfftw']:
         rg_fft_modules += [name]
 
 
+rg_test_shapes = [(128, 128), (179, 179), (512, 512)]
+
+rg_test_data = np.array(
+    [[0.38405405 + 0.32460996j, 0.02718878 + 0.08326207j,
+      0.78792080 + 0.81192595j, 0.17535687 + 0.68054781j,
+      0.93044845 + 0.71942995j, 0.21179999 + 0.00637665j],
+     [0.10905553 + 0.3027462j, 0.37361237 + 0.68434316j,
+      0.94070232 + 0.34129582j, 0.04658034 + 0.4575192j,
+      0.45057929 + 0.64297612j, 0.01007361 + 0.24953504j],
+     [0.39579662 + 0.70881906j, 0.01614435 + 0.82603832j,
+      0.84036344 + 0.50321592j, 0.87699553 + 0.40337862j,
+      0.11816016 + 0.43332373j, 0.76627757 + 0.66327959j],
+     [0.77272335 + 0.18277367j, 0.93341953 + 0.58105518j,
+      0.27227913 + 0.17458168j, 0.70204032 + 0.81397425j,
+      0.12422993 + 0.19215286j, 0.30897158 + 0.47364969j],
+     [0.24702012 + 0.54534373j, 0.55206013 + 0.98406613j,
+      0.57408167 + 0.55685406j, 0.87991341 + 0.52534323j,
+      0.93912604 + 0.97186519j, 0.77778942 + 0.45812051j],
+     [0.79367868 + 0.48149411j, 0.42484378 + 0.74870011j,
+      0.79611264 + 0.50926774j, 0.35372794 + 0.10468412j,
+      0.46140736 + 0.09449825j, 0.82044644 + 0.95992843j]])
+
 ###############################################################################
 
 class TestRGSpaceTransforms(unittest.TestCase):
@@ -45,6 +67,16 @@ class TestRGSpaceTransforms(unittest.TestCase):
         with assert_raises(TypeError):
             transformator.create(x, y, module=module)
 
+    @parameterized.expand(
+        itertools.product([0, 1, 2], [None, (1, 1), (10, 10)], [False, True]),
+        testcase_func_name=custom_name_func
+    )
+    def test_check_codomain_rgspecific(self, complexity, distances, harmonic):
+        x = RGSpace((8, 8), complexity=complexity,
+                    distances=distances, harmonic=harmonic)
+        assert(Transformation.check_codomain(x, x.get_codomain()))
+        assert(Transformation.check_codomain(x, x.get_codomain()))
+
     @parameterized.expand(rg_fft_modules, testcase_func_name=custom_name_func)
     def test_shapemismatch(self, module):
         x = RGSpace((8, 8))
@@ -55,7 +87,7 @@ class TestRGSpaceTransforms(unittest.TestCase):
             ).transform(b, axes=(0, 1, 2))
 
     @parameterized.expand(
-        itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]),
+        itertools.product(rg_fft_modules, rg_test_shapes),
         testcase_func_name=custom_name_func
     )
     def test_local_ndarray(self, module, shape):
@@ -69,7 +101,7 @@ class TestRGSpaceTransforms(unittest.TestCase):
         )
 
     @parameterized.expand(
-        itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]),
+        itertools.product(rg_fft_modules, rg_test_shapes),
         testcase_func_name=custom_name_func
     )
     def test_local_notzero(self, module, shape):
@@ -84,7 +116,7 @@ class TestRGSpaceTransforms(unittest.TestCase):
         )
 
     @parameterized.expand(
-        itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]),
+        itertools.product(rg_fft_modules, rg_test_shapes),
         testcase_func_name=custom_name_func
     )
     def test_not(self, module, shape):
@@ -98,22 +130,35 @@ class TestRGSpaceTransforms(unittest.TestCase):
             np.fft.fftn(a)
         )
 
-    # ndarray is not contiguous?
     @parameterized.expand(
-        itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]),
+        itertools.product(rg_test_shapes),
         testcase_func_name=custom_name_func
     )
-    def test_mpi_axesnone(self, module, shape):
+    def test_mpi_axesnone(self, shape):
         x = RGSpace(shape)
         a = np.ones(shape)
         b = d2o.distributed_data_object(a)
         assert np.allclose(
             transformator.create(
-                x, x.get_codomain(), module=module
+                x, x.get_codomain(), module='pyfftw'
+            ).transform(b),
+            np.fft.fftn(a)
+        )
+
+    @parameterized.expand(
+        itertools.product(rg_test_shapes),
+        testcase_func_name=custom_name_func
+    )
+    def test_mpi_axesnone_equal(self, shape):
+        x = RGSpace(shape)
+        a = np.ones(shape)
+        b = d2o.distributed_data_object(a, distribution_strategy='equal')
+        assert np.allclose(
+            transformator.create(
+                x, x.get_codomain(), module='pyfftw'
             ).transform(b),
             np.fft.fftn(a)
         )
 
-    #TODO: check what to do when cannot be distributed
 if __name__ == '__main__':
     unittest.main()
diff --git a/transformations/__init__.py b/transformations/__init__.py
new file mode 100644
index 000000000..9c2658a0c
--- /dev/null
+++ b/transformations/__init__.py
@@ -0,0 +1,3 @@
+from transformation_factory import TransformationFactory
+
+tf = TransformationFactory()
diff --git a/transforms/fftw.py b/transformations/fftw.py
similarity index 97%
rename from transforms/fftw.py
rename to transformations/fftw.py
index 9d66bba97..72d8c1235 100644
--- a/transforms/fftw.py
+++ b/transformations/fftw.py
@@ -18,11 +18,8 @@ class FFTW(Transform):
     """
 
     def __init__(self, domain, codomain):
-        if Transform.check_codomain(domain, codomain):
-            self.domain = domain
-            self.codomain = codomain
-        else:
-            raise ValueError("ERROR: Invalid codomain!")
+        self.domain = domain
+        self.codomain = codomain
 
         if 'pyfftw' not in gdi:
             raise ImportError("The module pyfftw is needed but not available.")
@@ -363,7 +360,7 @@ class FFTW(Transform):
 
         return return_val
 
-    def transform(self, val, axes=None, **kwargs):
+    def transform(self, val, axes, **kwargs):
         """
             The pyfftw transform function.
 
@@ -373,12 +370,6 @@ class FFTW(Transform):
                 The value-array of the field which is supposed to
                 be transformed.
 
-            domain : nifty.rg.nifty_rg.rg_space
-                The domain of the space which should be transformed.
-
-            codomain : nifty.rg.nifty_rg.rg_space
-                The target into which the field should be transformed.
-
             axes: tuple, None
                 The axes which should be transformed.
 
diff --git a/transforms/gfft.py b/transformations/gfft.py
similarity index 92%
rename from transforms/gfft.py
rename to transformations/gfft.py
index f928f87d6..2c860c0a3 100644
--- a/transforms/gfft.py
+++ b/transformations/gfft.py
@@ -17,14 +17,11 @@ class GFFT(Transform):
     """
 
     def __init__(self, domain, codomain, fft_module):
-        if Transform.check_codomain(domain, codomain):
-            self.domain = domain
-            self.codomain = codomain
-            self.fft_machine = fft_module
-        else:
-            raise ValueError("ERROR: Invalid codomain!")
+        self.domain = domain
+        self.codomain = codomain
+        self.fft_machine = fft_module
 
-    def transform(self, val, axes=None, **kwargs):
+    def transform(self, val, axes, **kwargs):
         """
             The gfft transform function.
 
diff --git a/transformations/transform.py b/transformations/transform.py
new file mode 100644
index 000000000..9752a6fc1
--- /dev/null
+++ b/transformations/transform.py
@@ -0,0 +1,32 @@
+from nifty import RGSpace
+from nifty.config import about
+
+import numpy as np
+
+
+class Transform(object):
+    """
+        A generic fft object without any implementation.
+    """
+
+
+    def __init__(self, domain, codomain):
+        pass
+
+    def transform(self, val, axes, **kwargs):
+        """
+            A generic ff-transform function.
+
+            Parameters
+            ----------
+            field_val : distributed_data_object
+                The value-array of the field which is supposed to
+                be transformed.
+
+            domain : nifty.rg.nifty_rg.rg_space
+                The domain of the space which should be transformed.
+
+            codomain : nifty.rg.nifty_rg.rg_space
+                The taget into which the field should be transformed.
+        """
+        raise NotImplementedError
diff --git a/transforms/transform.py b/transformations/transformation.py
similarity index 54%
rename from transforms/transform.py
rename to transformations/transformation.py
index a4c681d6a..aceaff0ad 100644
--- a/transforms/transform.py
+++ b/transformations/transformation.py
@@ -1,12 +1,16 @@
+from fftw import FFTW
+from gfft import  GFFT
+
+from nifty.config import about, dependency_injector as gdi
 from nifty import RGSpace
-from nifty.config import about
 
 import numpy as np
 
 
-class Transform(object):
+class Transformation(object):
     """
-        A generic fft object without any implementation.
+        A generic transformation which defines a static check_codomain
+        method for all transforms.
     """
 
     @staticmethod
@@ -69,23 +73,45 @@ class Transform(object):
 
         return True
 
-    def __init__(self, domain, codomain):
+    def __init__(self, domain, codomain, module=None):
         pass
 
-    def transform(self, val, axes, **kwargs):
-        """
-            A generic ff-transform function.
+    def transform(self, val, axes=None, **kwargs):
+        raise NotImplementedError
 
-            Parameters
-            ----------
-            field_val : distributed_data_object
-                The value-array of the field which is supposed to
-                be transformed.
 
-            domain : nifty.rg.nifty_rg.rg_space
-                The domain of the space which should be transformed.
+class RGRGTransformation(Transformation):
+    def __init__(self, domain, codomain, module=None):
+        if Transformation.check_codomain(domain, codomain):
+            if module is None:
+                if gdi.get('pyfftw') is None:
+                    if gdi.get('gfft') is None:
+                        self._transform =\
+                            GFFT(domain, codomain, gdi.get('gfft_dummy'))
+                    else:
+                        self._transform =\
+                            GFFT(domain, codomain, gdi.get('gfft'))
+                self._transform = FFTW(domain, codomain)
+            else:
+                if module == 'pyfftw':
+                    if gdi.get('pyfftw') is not None:
+                        self._transform = FFTW(domain, codomain)
+                    else:
+                        raise RuntimeError("ERROR: pyfftw is not available.")
+                elif module == 'gfft':
+                    if gdi.get('gfft') is not None:
+                        self._transform =\
+                            GFFT(domain, codomain, gdi.get('gfft'))
+                    else:
+                        raise RuntimeError("ERROR: gfft is not available.")
+                elif module == 'gfft_dummy':
+                    self._transform =\
+                        GFFT(domain, codomain, gdi.get('gfft_dummy'))
+                else:
+                    raise ValueError('Given FFT module is not known: ' +
+                                     str(module))
+        else:
+            raise ValueError("ERROR: Incompatible codomain!")
 
-            codomain : nifty.rg.nifty_rg.rg_space
-                The taget into which the field should be transformed.
-        """
-        raise NotImplementedError
+    def transform(self, val, axes=None, **kwargs):
+        return self._transform.transform(val, axes, **kwargs)
\ No newline at end of file
diff --git a/transformations/transformation_factory.py b/transformations/transformation_factory.py
new file mode 100644
index 000000000..9ca4dca3e
--- /dev/null
+++ b/transformations/transformation_factory.py
@@ -0,0 +1,29 @@
+import numpy as np
+
+from nifty.rg import RGSpace
+from nifty.lm import GLSpace, HPSpace, LMSpace
+
+from transformation import RGRGTransformation
+
+
+class TransformationFactory(object):
+    """
+        Transform factory which generates transform objects
+    """
+
+    def __init__(self):
+        # cache for storing the transform objects
+        self.cache = {}
+
+    def _get_transform(self, domain, codomain, module):
+        if isinstance(domain, RGSpace):
+            return RGRGTransformation(domain, codomain, module)
+
+    def create(self, domain, codomain, module=None):
+        key = domain.__hash__() ^ ((111 * codomain.__hash__()) ^
+                                   (179 * module.__hash__()))
+
+        if key not in self.cache:
+            self.cache[key] = self._get_transform(domain, codomain, module)
+
+        return self.cache[key]
\ No newline at end of file
diff --git a/transforms/__init__.py b/transforms/__init__.py
deleted file mode 100644
index 7cd28c54d..000000000
--- a/transforms/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from transform_factory import TransformFactory
-
-tf = TransformFactory()
diff --git a/transforms/transform_factory.py b/transforms/transform_factory.py
deleted file mode 100644
index afd174fc7..000000000
--- a/transforms/transform_factory.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import numpy as np
-
-from nifty.rg import RGSpace
-from nifty.lm import GLSpace, HPSpace, LMSpace
-from nifty.config import about, dependency_injector as gdi
-from gfft import GFFT
-from fftw import FFTW
-
-
-class TransformFactory(object):
-    """
-        Transform factory which generates transform objects
-    """
-
-    def __init__(self):
-        # cache for storing the transform objects
-        self.cache = {}
-
-    def _get_transform(self, domain, codomain, module):
-        if isinstance(domain, RGSpace):
-            # fftw -> gfft -> gfft_dummy
-            if module is None:
-                if gdi.get('pyfftw') is None:
-                    if gdi.get('gfft') is None:
-                        return GFFT(domain, codomain, gdi.get('gfft_dummy'))
-                    else:
-                        return GFFT(domain, codomain, gdi.get('gfft'))
-                return FFTW(domain, codomain)
-            else:
-                if module == 'pyfftw':
-                    if gdi.get('pyfftw') is not None:
-                        return FFTW(domain, codomain)
-                    else:
-                        raise RuntimeError("ERROR: pyfftw is not available.")
-                elif module == 'gfft':
-                    if gdi.get('gfft') is not None:
-                        return GFFT(domain, codomain, gdi.get('gfft'))
-                    else:
-                        raise RuntimeError("ERROR: gfft is not available.")
-                elif module == 'gfft_dummy':
-                    return GFFT(domain, codomain, gdi.get('gfft_dummy'))
-                else:
-                    raise ValueError('Given FFT module is not known: ' +
-                                     str(module))
-
-    def create(self, domain, codomain, module=None):
-        key = domain.__hash__() ^ ((111 * codomain.__hash__()) ^
-                                   (179 * module.__hash__()))
-
-        if key not in self.cache:
-            self.cache[key] = self._get_transform(domain, codomain, module)
-
-        return self.cache[key]
\ No newline at end of file
-- 
GitLab