Commit 2057a238 authored by Jait Dixit's avatar Jait Dixit

WIP: More refactoring

- rename transforms submodule to transformations
- Add a Transformation class which wraps the actual transform class
- Derive and implement RGRGTransformation
- Add tests
parent 4c6ef2f0
...@@ -93,4 +93,4 @@ from demos import get_demo_dir ...@@ -93,4 +93,4 @@ from demos import get_demo_dir
from pickling import _pickle_method, _unpickle_method from pickling import _pickle_method, _unpickle_method
#import pyximport; pyximport.install(pyimport = True) #import pyximport; pyximport.install(pyimport = True)
from transforms import tf as transformator from transformations import tf as transformator
...@@ -34,7 +34,7 @@ setup(name="ift_nifty", ...@@ -34,7 +34,7 @@ setup(name="ift_nifty",
url="http://www.mpa-garching.mpg.de/ift/nifty/", url="http://www.mpa-garching.mpg.de/ift/nifty/",
packages=["nifty", "nifty.demos", "nifty.rg", "nifty.lm", packages=["nifty", "nifty.demos", "nifty.rg", "nifty.lm",
"nifty.operators", "nifty.dummys", "nifty.field_types", "nifty.operators", "nifty.dummys", "nifty.field_types",
"nifty.config", "nifty.power", "nifty.transforms"], "nifty.config", "nifty.power", "nifty.transformations"],
package_dir={"nifty": ""}, package_dir={"nifty": ""},
zip_safe=False, zip_safe=False,
dependency_links=[ dependency_links=[
......
...@@ -7,7 +7,7 @@ import itertools ...@@ -7,7 +7,7 @@ import itertools
from nifty import RGSpace, LMSpace, HPSpace, GLSpace from nifty import RGSpace, LMSpace, HPSpace, GLSpace
from nifty import transformator from nifty import transformator
from nifty.transforms.transform import Transform from nifty.transformations.transformation import Transformation
from nifty.rg.rg_space import gc as RG_GC from nifty.rg.rg_space import gc as RG_GC
import d2o import d2o
...@@ -29,6 +29,28 @@ for name in ['gfft', 'gfft_dummy', 'pyfftw']: ...@@ -29,6 +29,28 @@ for name in ['gfft', 'gfft_dummy', 'pyfftw']:
rg_fft_modules += [name] rg_fft_modules += [name]
rg_test_shapes = [(128, 128), (179, 179), (512, 512)]
rg_test_data = np.array(
[[0.38405405 + 0.32460996j, 0.02718878 + 0.08326207j,
0.78792080 + 0.81192595j, 0.17535687 + 0.68054781j,
0.93044845 + 0.71942995j, 0.21179999 + 0.00637665j],
[0.10905553 + 0.3027462j, 0.37361237 + 0.68434316j,
0.94070232 + 0.34129582j, 0.04658034 + 0.4575192j,
0.45057929 + 0.64297612j, 0.01007361 + 0.24953504j],
[0.39579662 + 0.70881906j, 0.01614435 + 0.82603832j,
0.84036344 + 0.50321592j, 0.87699553 + 0.40337862j,
0.11816016 + 0.43332373j, 0.76627757 + 0.66327959j],
[0.77272335 + 0.18277367j, 0.93341953 + 0.58105518j,
0.27227913 + 0.17458168j, 0.70204032 + 0.81397425j,
0.12422993 + 0.19215286j, 0.30897158 + 0.47364969j],
[0.24702012 + 0.54534373j, 0.55206013 + 0.98406613j,
0.57408167 + 0.55685406j, 0.87991341 + 0.52534323j,
0.93912604 + 0.97186519j, 0.77778942 + 0.45812051j],
[0.79367868 + 0.48149411j, 0.42484378 + 0.74870011j,
0.79611264 + 0.50926774j, 0.35372794 + 0.10468412j,
0.46140736 + 0.09449825j, 0.82044644 + 0.95992843j]])
############################################################################### ###############################################################################
class TestRGSpaceTransforms(unittest.TestCase): class TestRGSpaceTransforms(unittest.TestCase):
...@@ -45,6 +67,16 @@ class TestRGSpaceTransforms(unittest.TestCase): ...@@ -45,6 +67,16 @@ class TestRGSpaceTransforms(unittest.TestCase):
with assert_raises(TypeError): with assert_raises(TypeError):
transformator.create(x, y, module=module) transformator.create(x, y, module=module)
@parameterized.expand(
itertools.product([0, 1, 2], [None, (1, 1), (10, 10)], [False, True]),
testcase_func_name=custom_name_func
)
def test_check_codomain_rgspecific(self, complexity, distances, harmonic):
x = RGSpace((8, 8), complexity=complexity,
distances=distances, harmonic=harmonic)
assert(Transformation.check_codomain(x, x.get_codomain()))
assert(Transformation.check_codomain(x, x.get_codomain()))
@parameterized.expand(rg_fft_modules, testcase_func_name=custom_name_func) @parameterized.expand(rg_fft_modules, testcase_func_name=custom_name_func)
def test_shapemismatch(self, module): def test_shapemismatch(self, module):
x = RGSpace((8, 8)) x = RGSpace((8, 8))
...@@ -55,7 +87,7 @@ class TestRGSpaceTransforms(unittest.TestCase): ...@@ -55,7 +87,7 @@ class TestRGSpaceTransforms(unittest.TestCase):
).transform(b, axes=(0, 1, 2)) ).transform(b, axes=(0, 1, 2))
@parameterized.expand( @parameterized.expand(
itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]), itertools.product(rg_fft_modules, rg_test_shapes),
testcase_func_name=custom_name_func testcase_func_name=custom_name_func
) )
def test_local_ndarray(self, module, shape): def test_local_ndarray(self, module, shape):
...@@ -69,7 +101,7 @@ class TestRGSpaceTransforms(unittest.TestCase): ...@@ -69,7 +101,7 @@ class TestRGSpaceTransforms(unittest.TestCase):
) )
@parameterized.expand( @parameterized.expand(
itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]), itertools.product(rg_fft_modules, rg_test_shapes),
testcase_func_name=custom_name_func testcase_func_name=custom_name_func
) )
def test_local_notzero(self, module, shape): def test_local_notzero(self, module, shape):
...@@ -84,7 +116,7 @@ class TestRGSpaceTransforms(unittest.TestCase): ...@@ -84,7 +116,7 @@ class TestRGSpaceTransforms(unittest.TestCase):
) )
@parameterized.expand( @parameterized.expand(
itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]), itertools.product(rg_fft_modules, rg_test_shapes),
testcase_func_name=custom_name_func testcase_func_name=custom_name_func
) )
def test_not(self, module, shape): def test_not(self, module, shape):
...@@ -98,22 +130,35 @@ class TestRGSpaceTransforms(unittest.TestCase): ...@@ -98,22 +130,35 @@ class TestRGSpaceTransforms(unittest.TestCase):
np.fft.fftn(a) np.fft.fftn(a)
) )
# ndarray is not contiguous?
@parameterized.expand( @parameterized.expand(
itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]), itertools.product(rg_test_shapes),
testcase_func_name=custom_name_func testcase_func_name=custom_name_func
) )
def test_mpi_axesnone(self, module, shape): def test_mpi_axesnone(self, shape):
x = RGSpace(shape) x = RGSpace(shape)
a = np.ones(shape) a = np.ones(shape)
b = d2o.distributed_data_object(a) b = d2o.distributed_data_object(a)
assert np.allclose( assert np.allclose(
transformator.create( transformator.create(
x, x.get_codomain(), module=module x, x.get_codomain(), module='pyfftw'
).transform(b),
np.fft.fftn(a)
)
@parameterized.expand(
itertools.product(rg_test_shapes),
testcase_func_name=custom_name_func
)
def test_mpi_axesnone_equal(self, shape):
x = RGSpace(shape)
a = np.ones(shape)
b = d2o.distributed_data_object(a, distribution_strategy='equal')
assert np.allclose(
transformator.create(
x, x.get_codomain(), module='pyfftw'
).transform(b), ).transform(b),
np.fft.fftn(a) np.fft.fftn(a)
) )
#TODO: check what to do when cannot be distributed
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
from transformation_factory import TransformationFactory
tf = TransformationFactory()
...@@ -18,11 +18,8 @@ class FFTW(Transform): ...@@ -18,11 +18,8 @@ class FFTW(Transform):
""" """
def __init__(self, domain, codomain): def __init__(self, domain, codomain):
if Transform.check_codomain(domain, codomain):
self.domain = domain self.domain = domain
self.codomain = codomain self.codomain = codomain
else:
raise ValueError("ERROR: Invalid codomain!")
if 'pyfftw' not in gdi: if 'pyfftw' not in gdi:
raise ImportError("The module pyfftw is needed but not available.") raise ImportError("The module pyfftw is needed but not available.")
...@@ -363,7 +360,7 @@ class FFTW(Transform): ...@@ -363,7 +360,7 @@ class FFTW(Transform):
return return_val return return_val
def transform(self, val, axes=None, **kwargs): def transform(self, val, axes, **kwargs):
""" """
The pyfftw transform function. The pyfftw transform function.
...@@ -373,12 +370,6 @@ class FFTW(Transform): ...@@ -373,12 +370,6 @@ class FFTW(Transform):
The value-array of the field which is supposed to The value-array of the field which is supposed to
be transformed. be transformed.
domain : nifty.rg.nifty_rg.rg_space
The domain of the space which should be transformed.
codomain : nifty.rg.nifty_rg.rg_space
The target into which the field should be transformed.
axes: tuple, None axes: tuple, None
The axes which should be transformed. The axes which should be transformed.
......
...@@ -17,14 +17,11 @@ class GFFT(Transform): ...@@ -17,14 +17,11 @@ class GFFT(Transform):
""" """
def __init__(self, domain, codomain, fft_module): def __init__(self, domain, codomain, fft_module):
if Transform.check_codomain(domain, codomain):
self.domain = domain self.domain = domain
self.codomain = codomain self.codomain = codomain
self.fft_machine = fft_module self.fft_machine = fft_module
else:
raise ValueError("ERROR: Invalid codomain!")
def transform(self, val, axes=None, **kwargs): def transform(self, val, axes, **kwargs):
""" """
The gfft transform function. The gfft transform function.
......
from nifty import RGSpace
from nifty.config import about
import numpy as np
class Transform(object):
"""
A generic fft object without any implementation.
"""
def __init__(self, domain, codomain):
pass
def transform(self, val, axes, **kwargs):
"""
A generic ff-transform function.
Parameters
----------
field_val : distributed_data_object
The value-array of the field which is supposed to
be transformed.
domain : nifty.rg.nifty_rg.rg_space
The domain of the space which should be transformed.
codomain : nifty.rg.nifty_rg.rg_space
The taget into which the field should be transformed.
"""
raise NotImplementedError
from fftw import FFTW
from gfft import GFFT
from nifty.config import about, dependency_injector as gdi
from nifty import RGSpace from nifty import RGSpace
from nifty.config import about
import numpy as np import numpy as np
class Transform(object): class Transformation(object):
""" """
A generic fft object without any implementation. A generic transformation which defines a static check_codomain
method for all transforms.
""" """
@staticmethod @staticmethod
...@@ -69,23 +73,45 @@ class Transform(object): ...@@ -69,23 +73,45 @@ class Transform(object):
return True return True
def __init__(self, domain, codomain): def __init__(self, domain, codomain, module=None):
pass pass
def transform(self, val, axes, **kwargs): def transform(self, val, axes=None, **kwargs):
""" raise NotImplementedError
A generic ff-transform function.
Parameters
----------
field_val : distributed_data_object
The value-array of the field which is supposed to
be transformed.
domain : nifty.rg.nifty_rg.rg_space class RGRGTransformation(Transformation):
The domain of the space which should be transformed. def __init__(self, domain, codomain, module=None):
if Transformation.check_codomain(domain, codomain):
if module is None:
if gdi.get('pyfftw') is None:
if gdi.get('gfft') is None:
self._transform =\
GFFT(domain, codomain, gdi.get('gfft_dummy'))
else:
self._transform =\
GFFT(domain, codomain, gdi.get('gfft'))
self._transform = FFTW(domain, codomain)
else:
if module == 'pyfftw':
if gdi.get('pyfftw') is not None:
self._transform = FFTW(domain, codomain)
else:
raise RuntimeError("ERROR: pyfftw is not available.")
elif module == 'gfft':
if gdi.get('gfft') is not None:
self._transform =\
GFFT(domain, codomain, gdi.get('gfft'))
else:
raise RuntimeError("ERROR: gfft is not available.")
elif module == 'gfft_dummy':
self._transform =\
GFFT(domain, codomain, gdi.get('gfft_dummy'))
else:
raise ValueError('Given FFT module is not known: ' +
str(module))
else:
raise ValueError("ERROR: Incompatible codomain!")
codomain : nifty.rg.nifty_rg.rg_space def transform(self, val, axes=None, **kwargs):
The taget into which the field should be transformed. return self._transform.transform(val, axes, **kwargs)
""" \ No newline at end of file
raise NotImplementedError
...@@ -2,12 +2,11 @@ import numpy as np ...@@ -2,12 +2,11 @@ import numpy as np
from nifty.rg import RGSpace from nifty.rg import RGSpace
from nifty.lm import GLSpace, HPSpace, LMSpace from nifty.lm import GLSpace, HPSpace, LMSpace
from nifty.config import about, dependency_injector as gdi
from gfft import GFFT
from fftw import FFTW
from transformation import RGRGTransformation
class TransformFactory(object):
class TransformationFactory(object):
""" """
Transform factory which generates transform objects Transform factory which generates transform objects
""" """
...@@ -18,30 +17,7 @@ class TransformFactory(object): ...@@ -18,30 +17,7 @@ class TransformFactory(object):
def _get_transform(self, domain, codomain, module): def _get_transform(self, domain, codomain, module):
if isinstance(domain, RGSpace): if isinstance(domain, RGSpace):
# fftw -> gfft -> gfft_dummy return RGRGTransformation(domain, codomain, module)
if module is None:
if gdi.get('pyfftw') is None:
if gdi.get('gfft') is None:
return GFFT(domain, codomain, gdi.get('gfft_dummy'))
else:
return GFFT(domain, codomain, gdi.get('gfft'))
return FFTW(domain, codomain)
else:
if module == 'pyfftw':
if gdi.get('pyfftw') is not None:
return FFTW(domain, codomain)
else:
raise RuntimeError("ERROR: pyfftw is not available.")
elif module == 'gfft':
if gdi.get('gfft') is not None:
return GFFT(domain, codomain, gdi.get('gfft'))
else:
raise RuntimeError("ERROR: gfft is not available.")
elif module == 'gfft_dummy':
return GFFT(domain, codomain, gdi.get('gfft_dummy'))
else:
raise ValueError('Given FFT module is not known: ' +
str(module))
def create(self, domain, codomain, module=None): def create(self, domain, codomain, module=None):
key = domain.__hash__() ^ ((111 * codomain.__hash__()) ^ key = domain.__hash__() ^ ((111 * codomain.__hash__()) ^
......
from transform_factory import TransformFactory
tf = TransformFactory()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment