diff --git a/__init__.py b/__init__.py index a724d7eeee5be38720256a15b2fef281dbff14fb..ce246734e04c3f3bc1402581e3e1e1ffb8e46bef 100644 --- a/__init__.py +++ b/__init__.py @@ -93,4 +93,4 @@ from demos import get_demo_dir from pickling import _pickle_method, _unpickle_method #import pyximport; pyximport.install(pyimport = True) -from transforms import tf as transformator +from transformations import tf as transformator diff --git a/setup.py b/setup.py index cef1133592eddc03151830dfc2ba64cfb377feac..7e933ad73cc6a9cdbe5ce9a4a8f63fd4441149c5 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ setup(name="ift_nifty", url="http://www.mpa-garching.mpg.de/ift/nifty/", packages=["nifty", "nifty.demos", "nifty.rg", "nifty.lm", "nifty.operators", "nifty.dummys", "nifty.field_types", - "nifty.config", "nifty.power", "nifty.transforms"], + "nifty.config", "nifty.power", "nifty.transformations"], package_dir={"nifty": ""}, zip_safe=False, dependency_links=[ diff --git a/test/test_nifty_transforms.py b/test/test_nifty_transforms.py index 295199d57c92db9d19d37517de2b59ea938fe3ab..400761a4d135a8ebefa8ec6ad5350569925ae028 100644 --- a/test/test_nifty_transforms.py +++ b/test/test_nifty_transforms.py @@ -7,7 +7,7 @@ import itertools from nifty import RGSpace, LMSpace, HPSpace, GLSpace from nifty import transformator -from nifty.transforms.transform import Transform +from nifty.transformations.transformation import Transformation from nifty.rg.rg_space import gc as RG_GC import d2o @@ -29,6 +29,28 @@ for name in ['gfft', 'gfft_dummy', 'pyfftw']: rg_fft_modules += [name] +rg_test_shapes = [(128, 128), (179, 179), (512, 512)] + +rg_test_data = np.array( + [[0.38405405 + 0.32460996j, 0.02718878 + 0.08326207j, + 0.78792080 + 0.81192595j, 0.17535687 + 0.68054781j, + 0.93044845 + 0.71942995j, 0.21179999 + 0.00637665j], + [0.10905553 + 0.3027462j, 0.37361237 + 0.68434316j, + 0.94070232 + 0.34129582j, 0.04658034 + 0.4575192j, + 0.45057929 + 0.64297612j, 0.01007361 + 0.24953504j], + [0.39579662 + 0.70881906j, 0.01614435 + 0.82603832j, + 0.84036344 + 0.50321592j, 0.87699553 + 0.40337862j, + 0.11816016 + 0.43332373j, 0.76627757 + 0.66327959j], + [0.77272335 + 0.18277367j, 0.93341953 + 0.58105518j, + 0.27227913 + 0.17458168j, 0.70204032 + 0.81397425j, + 0.12422993 + 0.19215286j, 0.30897158 + 0.47364969j], + [0.24702012 + 0.54534373j, 0.55206013 + 0.98406613j, + 0.57408167 + 0.55685406j, 0.87991341 + 0.52534323j, + 0.93912604 + 0.97186519j, 0.77778942 + 0.45812051j], + [0.79367868 + 0.48149411j, 0.42484378 + 0.74870011j, + 0.79611264 + 0.50926774j, 0.35372794 + 0.10468412j, + 0.46140736 + 0.09449825j, 0.82044644 + 0.95992843j]]) + ############################################################################### class TestRGSpaceTransforms(unittest.TestCase): @@ -45,6 +67,16 @@ class TestRGSpaceTransforms(unittest.TestCase): with assert_raises(TypeError): transformator.create(x, y, module=module) + @parameterized.expand( + itertools.product([0, 1, 2], [None, (1, 1), (10, 10)], [False, True]), + testcase_func_name=custom_name_func + ) + def test_check_codomain_rgspecific(self, complexity, distances, harmonic): + x = RGSpace((8, 8), complexity=complexity, + distances=distances, harmonic=harmonic) + assert(Transformation.check_codomain(x, x.get_codomain())) + assert(Transformation.check_codomain(x, x.get_codomain())) + @parameterized.expand(rg_fft_modules, testcase_func_name=custom_name_func) def test_shapemismatch(self, module): x = RGSpace((8, 8)) @@ -55,7 +87,7 @@ class TestRGSpaceTransforms(unittest.TestCase): ).transform(b, axes=(0, 1, 2)) @parameterized.expand( - itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]), + itertools.product(rg_fft_modules, rg_test_shapes), testcase_func_name=custom_name_func ) def test_local_ndarray(self, module, shape): @@ -69,7 +101,7 @@ class TestRGSpaceTransforms(unittest.TestCase): ) @parameterized.expand( - itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]), + itertools.product(rg_fft_modules, rg_test_shapes), testcase_func_name=custom_name_func ) def test_local_notzero(self, module, shape): @@ -84,7 +116,7 @@ class TestRGSpaceTransforms(unittest.TestCase): ) @parameterized.expand( - itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]), + itertools.product(rg_fft_modules, rg_test_shapes), testcase_func_name=custom_name_func ) def test_not(self, module, shape): @@ -98,22 +130,35 @@ class TestRGSpaceTransforms(unittest.TestCase): np.fft.fftn(a) ) - # ndarray is not contiguous? @parameterized.expand( - itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]), + itertools.product(rg_test_shapes), testcase_func_name=custom_name_func ) - def test_mpi_axesnone(self, module, shape): + def test_mpi_axesnone(self, shape): x = RGSpace(shape) a = np.ones(shape) b = d2o.distributed_data_object(a) assert np.allclose( transformator.create( - x, x.get_codomain(), module=module + x, x.get_codomain(), module='pyfftw' + ).transform(b), + np.fft.fftn(a) + ) + + @parameterized.expand( + itertools.product(rg_test_shapes), + testcase_func_name=custom_name_func + ) + def test_mpi_axesnone_equal(self, shape): + x = RGSpace(shape) + a = np.ones(shape) + b = d2o.distributed_data_object(a, distribution_strategy='equal') + assert np.allclose( + transformator.create( + x, x.get_codomain(), module='pyfftw' ).transform(b), np.fft.fftn(a) ) - #TODO: check what to do when cannot be distributed if __name__ == '__main__': unittest.main() diff --git a/transformations/__init__.py b/transformations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9c2658a0ccf235c99835f3dd5ce0faf8c3d2b12b --- /dev/null +++ b/transformations/__init__.py @@ -0,0 +1,3 @@ +from transformation_factory import TransformationFactory + +tf = TransformationFactory() diff --git a/transforms/fftw.py b/transformations/fftw.py similarity index 97% rename from transforms/fftw.py rename to transformations/fftw.py index 9d66bba977916c662f4523a4df2d3b7004db1c96..72d8c123541b3ed9a81bd0d2d8acb5cec5e5e799 100644 --- a/transforms/fftw.py +++ b/transformations/fftw.py @@ -18,11 +18,8 @@ class FFTW(Transform): """ def __init__(self, domain, codomain): - if Transform.check_codomain(domain, codomain): - self.domain = domain - self.codomain = codomain - else: - raise ValueError("ERROR: Invalid codomain!") + self.domain = domain + self.codomain = codomain if 'pyfftw' not in gdi: raise ImportError("The module pyfftw is needed but not available.") @@ -363,7 +360,7 @@ class FFTW(Transform): return return_val - def transform(self, val, axes=None, **kwargs): + def transform(self, val, axes, **kwargs): """ The pyfftw transform function. @@ -373,12 +370,6 @@ class FFTW(Transform): The value-array of the field which is supposed to be transformed. - domain : nifty.rg.nifty_rg.rg_space - The domain of the space which should be transformed. - - codomain : nifty.rg.nifty_rg.rg_space - The target into which the field should be transformed. - axes: tuple, None The axes which should be transformed. diff --git a/transforms/gfft.py b/transformations/gfft.py similarity index 92% rename from transforms/gfft.py rename to transformations/gfft.py index f928f87d69fed3c6fa9c8e4db6b750534c4b74a9..2c860c0a30df3976ab40c3174c9d544144d7ebc6 100644 --- a/transforms/gfft.py +++ b/transformations/gfft.py @@ -17,14 +17,11 @@ class GFFT(Transform): """ def __init__(self, domain, codomain, fft_module): - if Transform.check_codomain(domain, codomain): - self.domain = domain - self.codomain = codomain - self.fft_machine = fft_module - else: - raise ValueError("ERROR: Invalid codomain!") + self.domain = domain + self.codomain = codomain + self.fft_machine = fft_module - def transform(self, val, axes=None, **kwargs): + def transform(self, val, axes, **kwargs): """ The gfft transform function. diff --git a/transformations/transform.py b/transformations/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..9752a6fc13ab670953617380764dc1e358e9bf1f --- /dev/null +++ b/transformations/transform.py @@ -0,0 +1,32 @@ +from nifty import RGSpace +from nifty.config import about + +import numpy as np + + +class Transform(object): + """ + A generic fft object without any implementation. + """ + + + def __init__(self, domain, codomain): + pass + + def transform(self, val, axes, **kwargs): + """ + A generic ff-transform function. + + Parameters + ---------- + field_val : distributed_data_object + The value-array of the field which is supposed to + be transformed. + + domain : nifty.rg.nifty_rg.rg_space + The domain of the space which should be transformed. + + codomain : nifty.rg.nifty_rg.rg_space + The taget into which the field should be transformed. + """ + raise NotImplementedError diff --git a/transforms/transform.py b/transformations/transformation.py similarity index 54% rename from transforms/transform.py rename to transformations/transformation.py index a4c681d6a7c15e6b6b501f9793f4cc0c0463a0c8..aceaff0ad121a54d388ce2a9f7a7631fc34db5f3 100644 --- a/transforms/transform.py +++ b/transformations/transformation.py @@ -1,12 +1,16 @@ +from fftw import FFTW +from gfft import GFFT + +from nifty.config import about, dependency_injector as gdi from nifty import RGSpace -from nifty.config import about import numpy as np -class Transform(object): +class Transformation(object): """ - A generic fft object without any implementation. + A generic transformation which defines a static check_codomain + method for all transforms. """ @staticmethod @@ -69,23 +73,45 @@ class Transform(object): return True - def __init__(self, domain, codomain): + def __init__(self, domain, codomain, module=None): pass - def transform(self, val, axes, **kwargs): - """ - A generic ff-transform function. + def transform(self, val, axes=None, **kwargs): + raise NotImplementedError - Parameters - ---------- - field_val : distributed_data_object - The value-array of the field which is supposed to - be transformed. - domain : nifty.rg.nifty_rg.rg_space - The domain of the space which should be transformed. +class RGRGTransformation(Transformation): + def __init__(self, domain, codomain, module=None): + if Transformation.check_codomain(domain, codomain): + if module is None: + if gdi.get('pyfftw') is None: + if gdi.get('gfft') is None: + self._transform =\ + GFFT(domain, codomain, gdi.get('gfft_dummy')) + else: + self._transform =\ + GFFT(domain, codomain, gdi.get('gfft')) + self._transform = FFTW(domain, codomain) + else: + if module == 'pyfftw': + if gdi.get('pyfftw') is not None: + self._transform = FFTW(domain, codomain) + else: + raise RuntimeError("ERROR: pyfftw is not available.") + elif module == 'gfft': + if gdi.get('gfft') is not None: + self._transform =\ + GFFT(domain, codomain, gdi.get('gfft')) + else: + raise RuntimeError("ERROR: gfft is not available.") + elif module == 'gfft_dummy': + self._transform =\ + GFFT(domain, codomain, gdi.get('gfft_dummy')) + else: + raise ValueError('Given FFT module is not known: ' + + str(module)) + else: + raise ValueError("ERROR: Incompatible codomain!") - codomain : nifty.rg.nifty_rg.rg_space - The taget into which the field should be transformed. - """ - raise NotImplementedError + def transform(self, val, axes=None, **kwargs): + return self._transform.transform(val, axes, **kwargs) \ No newline at end of file diff --git a/transformations/transformation_factory.py b/transformations/transformation_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca4dca3e4807c8c2410ca31b213a3871b5865f7 --- /dev/null +++ b/transformations/transformation_factory.py @@ -0,0 +1,29 @@ +import numpy as np + +from nifty.rg import RGSpace +from nifty.lm import GLSpace, HPSpace, LMSpace + +from transformation import RGRGTransformation + + +class TransformationFactory(object): + """ + Transform factory which generates transform objects + """ + + def __init__(self): + # cache for storing the transform objects + self.cache = {} + + def _get_transform(self, domain, codomain, module): + if isinstance(domain, RGSpace): + return RGRGTransformation(domain, codomain, module) + + def create(self, domain, codomain, module=None): + key = domain.__hash__() ^ ((111 * codomain.__hash__()) ^ + (179 * module.__hash__())) + + if key not in self.cache: + self.cache[key] = self._get_transform(domain, codomain, module) + + return self.cache[key] \ No newline at end of file diff --git a/transforms/__init__.py b/transforms/__init__.py deleted file mode 100644 index 7cd28c54d939fe6b7f2c6f86413e88f7b8dfd93e..0000000000000000000000000000000000000000 --- a/transforms/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from transform_factory import TransformFactory - -tf = TransformFactory() diff --git a/transforms/transform_factory.py b/transforms/transform_factory.py deleted file mode 100644 index afd174fc7d4c38674369b636d249fd41266d6707..0000000000000000000000000000000000000000 --- a/transforms/transform_factory.py +++ /dev/null @@ -1,53 +0,0 @@ -import numpy as np - -from nifty.rg import RGSpace -from nifty.lm import GLSpace, HPSpace, LMSpace -from nifty.config import about, dependency_injector as gdi -from gfft import GFFT -from fftw import FFTW - - -class TransformFactory(object): - """ - Transform factory which generates transform objects - """ - - def __init__(self): - # cache for storing the transform objects - self.cache = {} - - def _get_transform(self, domain, codomain, module): - if isinstance(domain, RGSpace): - # fftw -> gfft -> gfft_dummy - if module is None: - if gdi.get('pyfftw') is None: - if gdi.get('gfft') is None: - return GFFT(domain, codomain, gdi.get('gfft_dummy')) - else: - return GFFT(domain, codomain, gdi.get('gfft')) - return FFTW(domain, codomain) - else: - if module == 'pyfftw': - if gdi.get('pyfftw') is not None: - return FFTW(domain, codomain) - else: - raise RuntimeError("ERROR: pyfftw is not available.") - elif module == 'gfft': - if gdi.get('gfft') is not None: - return GFFT(domain, codomain, gdi.get('gfft')) - else: - raise RuntimeError("ERROR: gfft is not available.") - elif module == 'gfft_dummy': - return GFFT(domain, codomain, gdi.get('gfft_dummy')) - else: - raise ValueError('Given FFT module is not known: ' + - str(module)) - - def create(self, domain, codomain, module=None): - key = domain.__hash__() ^ ((111 * codomain.__hash__()) ^ - (179 * module.__hash__())) - - if key not in self.cache: - self.cache[key] = self._get_transform(domain, codomain, module) - - return self.cache[key] \ No newline at end of file