Commit 2057a238 authored by Jait Dixit's avatar Jait Dixit

WIP: More refactoring

- rename transforms submodule to transformations
- Add a Transformation class which wraps the actual transform class
- Derive and implement RGRGTransformation
- Add tests
parent 4c6ef2f0
......@@ -93,4 +93,4 @@ from demos import get_demo_dir
from pickling import _pickle_method, _unpickle_method
#import pyximport; pyximport.install(pyimport = True)
from transforms import tf as transformator
from transformations import tf as transformator
......@@ -34,7 +34,7 @@ setup(name="ift_nifty",
url="http://www.mpa-garching.mpg.de/ift/nifty/",
packages=["nifty", "nifty.demos", "nifty.rg", "nifty.lm",
"nifty.operators", "nifty.dummys", "nifty.field_types",
"nifty.config", "nifty.power", "nifty.transforms"],
"nifty.config", "nifty.power", "nifty.transformations"],
package_dir={"nifty": ""},
zip_safe=False,
dependency_links=[
......
......@@ -7,7 +7,7 @@ import itertools
from nifty import RGSpace, LMSpace, HPSpace, GLSpace
from nifty import transformator
from nifty.transforms.transform import Transform
from nifty.transformations.transformation import Transformation
from nifty.rg.rg_space import gc as RG_GC
import d2o
......@@ -29,6 +29,28 @@ for name in ['gfft', 'gfft_dummy', 'pyfftw']:
rg_fft_modules += [name]
rg_test_shapes = [(128, 128), (179, 179), (512, 512)]
rg_test_data = np.array(
[[0.38405405 + 0.32460996j, 0.02718878 + 0.08326207j,
0.78792080 + 0.81192595j, 0.17535687 + 0.68054781j,
0.93044845 + 0.71942995j, 0.21179999 + 0.00637665j],
[0.10905553 + 0.3027462j, 0.37361237 + 0.68434316j,
0.94070232 + 0.34129582j, 0.04658034 + 0.4575192j,
0.45057929 + 0.64297612j, 0.01007361 + 0.24953504j],
[0.39579662 + 0.70881906j, 0.01614435 + 0.82603832j,
0.84036344 + 0.50321592j, 0.87699553 + 0.40337862j,
0.11816016 + 0.43332373j, 0.76627757 + 0.66327959j],
[0.77272335 + 0.18277367j, 0.93341953 + 0.58105518j,
0.27227913 + 0.17458168j, 0.70204032 + 0.81397425j,
0.12422993 + 0.19215286j, 0.30897158 + 0.47364969j],
[0.24702012 + 0.54534373j, 0.55206013 + 0.98406613j,
0.57408167 + 0.55685406j, 0.87991341 + 0.52534323j,
0.93912604 + 0.97186519j, 0.77778942 + 0.45812051j],
[0.79367868 + 0.48149411j, 0.42484378 + 0.74870011j,
0.79611264 + 0.50926774j, 0.35372794 + 0.10468412j,
0.46140736 + 0.09449825j, 0.82044644 + 0.95992843j]])
###############################################################################
class TestRGSpaceTransforms(unittest.TestCase):
......@@ -45,6 +67,16 @@ class TestRGSpaceTransforms(unittest.TestCase):
with assert_raises(TypeError):
transformator.create(x, y, module=module)
@parameterized.expand(
itertools.product([0, 1, 2], [None, (1, 1), (10, 10)], [False, True]),
testcase_func_name=custom_name_func
)
def test_check_codomain_rgspecific(self, complexity, distances, harmonic):
x = RGSpace((8, 8), complexity=complexity,
distances=distances, harmonic=harmonic)
assert(Transformation.check_codomain(x, x.get_codomain()))
assert(Transformation.check_codomain(x, x.get_codomain()))
@parameterized.expand(rg_fft_modules, testcase_func_name=custom_name_func)
def test_shapemismatch(self, module):
x = RGSpace((8, 8))
......@@ -55,7 +87,7 @@ class TestRGSpaceTransforms(unittest.TestCase):
).transform(b, axes=(0, 1, 2))
@parameterized.expand(
itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]),
itertools.product(rg_fft_modules, rg_test_shapes),
testcase_func_name=custom_name_func
)
def test_local_ndarray(self, module, shape):
......@@ -69,7 +101,7 @@ class TestRGSpaceTransforms(unittest.TestCase):
)
@parameterized.expand(
itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]),
itertools.product(rg_fft_modules, rg_test_shapes),
testcase_func_name=custom_name_func
)
def test_local_notzero(self, module, shape):
......@@ -84,7 +116,7 @@ class TestRGSpaceTransforms(unittest.TestCase):
)
@parameterized.expand(
itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]),
itertools.product(rg_fft_modules, rg_test_shapes),
testcase_func_name=custom_name_func
)
def test_not(self, module, shape):
......@@ -98,22 +130,35 @@ class TestRGSpaceTransforms(unittest.TestCase):
np.fft.fftn(a)
)
# ndarray is not contiguous?
@parameterized.expand(
itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]),
itertools.product(rg_test_shapes),
testcase_func_name=custom_name_func
)
def test_mpi_axesnone(self, module, shape):
def test_mpi_axesnone(self, shape):
x = RGSpace(shape)
a = np.ones(shape)
b = d2o.distributed_data_object(a)
assert np.allclose(
transformator.create(
x, x.get_codomain(), module=module
x, x.get_codomain(), module='pyfftw'
).transform(b),
np.fft.fftn(a)
)
@parameterized.expand(
itertools.product(rg_test_shapes),
testcase_func_name=custom_name_func
)
def test_mpi_axesnone_equal(self, shape):
x = RGSpace(shape)
a = np.ones(shape)
b = d2o.distributed_data_object(a, distribution_strategy='equal')
assert np.allclose(
transformator.create(
x, x.get_codomain(), module='pyfftw'
).transform(b),
np.fft.fftn(a)
)
#TODO: check what to do when cannot be distributed
if __name__ == '__main__':
unittest.main()
from transformation_factory import TransformationFactory
tf = TransformationFactory()
......@@ -18,11 +18,8 @@ class FFTW(Transform):
"""
def __init__(self, domain, codomain):
if Transform.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
raise ValueError("ERROR: Invalid codomain!")
self.domain = domain
self.codomain = codomain
if 'pyfftw' not in gdi:
raise ImportError("The module pyfftw is needed but not available.")
......@@ -363,7 +360,7 @@ class FFTW(Transform):
return return_val
def transform(self, val, axes=None, **kwargs):
def transform(self, val, axes, **kwargs):
"""
The pyfftw transform function.
......@@ -373,12 +370,6 @@ class FFTW(Transform):
The value-array of the field which is supposed to
be transformed.
domain : nifty.rg.nifty_rg.rg_space
The domain of the space which should be transformed.
codomain : nifty.rg.nifty_rg.rg_space
The target into which the field should be transformed.
axes: tuple, None
The axes which should be transformed.
......
......@@ -17,14 +17,11 @@ class GFFT(Transform):
"""
def __init__(self, domain, codomain, fft_module):
if Transform.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
self.fft_machine = fft_module
else:
raise ValueError("ERROR: Invalid codomain!")
self.domain = domain
self.codomain = codomain
self.fft_machine = fft_module
def transform(self, val, axes=None, **kwargs):
def transform(self, val, axes, **kwargs):
"""
The gfft transform function.
......
from nifty import RGSpace
from nifty.config import about
import numpy as np
class Transform(object):
"""
A generic fft object without any implementation.
"""
def __init__(self, domain, codomain):
pass
def transform(self, val, axes, **kwargs):
"""
A generic ff-transform function.
Parameters
----------
field_val : distributed_data_object
The value-array of the field which is supposed to
be transformed.
domain : nifty.rg.nifty_rg.rg_space
The domain of the space which should be transformed.
codomain : nifty.rg.nifty_rg.rg_space
The taget into which the field should be transformed.
"""
raise NotImplementedError
from fftw import FFTW
from gfft import GFFT
from nifty.config import about, dependency_injector as gdi
from nifty import RGSpace
from nifty.config import about
import numpy as np
class Transform(object):
class Transformation(object):
"""
A generic fft object without any implementation.
A generic transformation which defines a static check_codomain
method for all transforms.
"""
@staticmethod
......@@ -69,23 +73,45 @@ class Transform(object):
return True
def __init__(self, domain, codomain):
def __init__(self, domain, codomain, module=None):
pass
def transform(self, val, axes, **kwargs):
"""
A generic ff-transform function.
def transform(self, val, axes=None, **kwargs):
raise NotImplementedError
Parameters
----------
field_val : distributed_data_object
The value-array of the field which is supposed to
be transformed.
domain : nifty.rg.nifty_rg.rg_space
The domain of the space which should be transformed.
class RGRGTransformation(Transformation):
def __init__(self, domain, codomain, module=None):
if Transformation.check_codomain(domain, codomain):
if module is None:
if gdi.get('pyfftw') is None:
if gdi.get('gfft') is None:
self._transform =\
GFFT(domain, codomain, gdi.get('gfft_dummy'))
else:
self._transform =\
GFFT(domain, codomain, gdi.get('gfft'))
self._transform = FFTW(domain, codomain)
else:
if module == 'pyfftw':
if gdi.get('pyfftw') is not None:
self._transform = FFTW(domain, codomain)
else:
raise RuntimeError("ERROR: pyfftw is not available.")
elif module == 'gfft':
if gdi.get('gfft') is not None:
self._transform =\
GFFT(domain, codomain, gdi.get('gfft'))
else:
raise RuntimeError("ERROR: gfft is not available.")
elif module == 'gfft_dummy':
self._transform =\
GFFT(domain, codomain, gdi.get('gfft_dummy'))
else:
raise ValueError('Given FFT module is not known: ' +
str(module))
else:
raise ValueError("ERROR: Incompatible codomain!")
codomain : nifty.rg.nifty_rg.rg_space
The taget into which the field should be transformed.
"""
raise NotImplementedError
def transform(self, val, axes=None, **kwargs):
return self._transform.transform(val, axes, **kwargs)
\ No newline at end of file
......@@ -2,12 +2,11 @@ import numpy as np
from nifty.rg import RGSpace
from nifty.lm import GLSpace, HPSpace, LMSpace
from nifty.config import about, dependency_injector as gdi
from gfft import GFFT
from fftw import FFTW
from transformation import RGRGTransformation
class TransformFactory(object):
class TransformationFactory(object):
"""
Transform factory which generates transform objects
"""
......@@ -18,30 +17,7 @@ class TransformFactory(object):
def _get_transform(self, domain, codomain, module):
if isinstance(domain, RGSpace):
# fftw -> gfft -> gfft_dummy
if module is None:
if gdi.get('pyfftw') is None:
if gdi.get('gfft') is None:
return GFFT(domain, codomain, gdi.get('gfft_dummy'))
else:
return GFFT(domain, codomain, gdi.get('gfft'))
return FFTW(domain, codomain)
else:
if module == 'pyfftw':
if gdi.get('pyfftw') is not None:
return FFTW(domain, codomain)
else:
raise RuntimeError("ERROR: pyfftw is not available.")
elif module == 'gfft':
if gdi.get('gfft') is not None:
return GFFT(domain, codomain, gdi.get('gfft'))
else:
raise RuntimeError("ERROR: gfft is not available.")
elif module == 'gfft_dummy':
return GFFT(domain, codomain, gdi.get('gfft_dummy'))
else:
raise ValueError('Given FFT module is not known: ' +
str(module))
return RGRGTransformation(domain, codomain, module)
def create(self, domain, codomain, module=None):
key = domain.__hash__() ^ ((111 * codomain.__hash__()) ^
......
from transform_factory import TransformFactory
tf = TransformFactory()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment