Commit 4c6ef2f0 authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: Refactoring

- Merge _get_transform_override and _get_transform
- Move check_codomain for rg_space to Transform as a staticmethod
- Add tests for transforms
parent 0d492a58
import nifty as nt
import numpy as np
import unittest
import d2o
class TestFFTWTransform(unittest.TestCase):
def test_comm(self):
x = nt.RGSpace((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = d2o.distributed_data_object(a)
b.comm = [1, 2, 3] # change comm to something not supported
with self.assertRaises(RuntimeError):
x.fft_machine.transform(b, x, x.get_codomain())
def test_shapemismatch(self):
x = nt.RGSpace((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = x.cast(a)
with self.assertRaises(ValueError):
x.fft_machine.transform(b, x, x.get_codomain(), axes=(0, 1, 2))
def test_local_ndarray(self):
x = nt.RGSpace((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
self.assertTrue(
np.allclose(
x.fft_machine.transform(a, x, x.get_codomain()),
np.fft.fftn(a)
), 'results do not match numpy.fft.fftn'
)
def test_local_notzero(self):
x = nt.RGSpace(8, fft_module='pyfftw')
a = np.ones((8, 8))
b = x.cast(a)
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain(), axes=(1,)),
np.fft.fftn(a, axes=(1,))
), 'results do not match numpy.fft.fftn'
)
def test_not(self):
x = nt.RGSpace((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = d2o.distributed_data_object(a, distribution_strategy='not')
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain()),
np.fft.fftn(a)
), 'results do not match numpy.fft.fftn'
)
def test_mpi_axesnone(self):
x = nt.RGSpace((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = x.cast(a)
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain()),
np.fft.fftn(a)
), 'results do not match numpy.fft.fftn'
)
def test_mpi_axesnone_equal(self):
x = nt.RGSpace((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = d2o.distributed_data_object(a, distribution_strategy='equal')
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain()),
np.fft.fftn(a)
), 'results do not match numpy.fft.fftn'
)
def test_mpi_axesall(self):
x = nt.RGSpace((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = x.cast(a)
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain(), axes=(0, 1)),
np.fft.fftn(a)
), 'results do not match numpy.fft.fftn'
)
def test_mpi_axesall_equal(self):
x = nt.RGSpace((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = d2o.distributed_data_object(a, distribution_strategy='equal')
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain(), axes=(0, 1)),
np.fft.fftn(a)
), 'results do not match numpy.fft.fftn'
)
def test_mpi_zero(self):
x = nt.RGSpace(8, fft_module='pyfftw')
a = np.ones((8, 8)) + 1j*np.zeros((8, 8))
b = x.cast(a)
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain(), axes=(0,)),
np.fft.fftn(a, axes=(0,))
), 'results do not match numpy.fft.fftn'
)
def test_mpi_zero_equal(self):
x = nt.RGSpace(8, fft_module='pyfftw')
a = np.ones((8, 8)) + 1j*np.zeros((8, 8))
b = d2o.distributed_data_object(a, distribution_strategy='equal')
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain(), axes=(0,)),
np.fft.fftn(a, axes=(0,))
), 'results do not match numpy.fft.fftn'
)
def test_mpi_zero_not(self):
x = nt.RGSpace(8, fft_module='pyfftw')
a = np.ones((8, 8)) + 1j*np.zeros((8, 8))
b = d2o.distributed_data_object(a, distribution_strategy='not')
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain(), axes=(0,)),
np.fft.fftn(a, axes=(0,))
), 'results do not match numpy.fft.fftn'
)
if __name__ == '__main__':
unittest.main()
import numpy as np
from numpy.testing import assert_equal, assert_almost_equal, assert_raises
from nose_parameterized import parameterized
import unittest
import itertools
from nifty import RGSpace, LMSpace, HPSpace, GLSpace
from nifty import transformator
from nifty.transforms.transform import Transform
from nifty.rg.rg_space import gc as RG_GC
import d2o
###############################################################################
def custom_name_func(testcase_func, param_num, param):
return "%s_%s" % (
testcase_func.__name__,
parameterized.to_safe_name("_".join(str(x) for x in param.args)),
)
###############################################################################
rg_fft_modules = []
for name in ['gfft', 'gfft_dummy', 'pyfftw']:
if RG_GC.validQ('fft_module', name):
rg_fft_modules += [name]
###############################################################################
class TestRGSpaceTransforms(unittest.TestCase):
@parameterized.expand(rg_fft_modules, testcase_func_name=custom_name_func)
def test_check_codomain_none(self, module):
x = RGSpace((8, 8))
with assert_raises(ValueError):
transformator.create(x, None, module=module)
@parameterized.expand(rg_fft_modules, testcase_func_name=custom_name_func)
def test_check_codomain_mismatch(self, module):
x = RGSpace((8, 8))
y = LMSpace(8)
with assert_raises(TypeError):
transformator.create(x, y, module=module)
@parameterized.expand(rg_fft_modules, testcase_func_name=custom_name_func)
def test_shapemismatch(self, module):
x = RGSpace((8, 8))
b = d2o.distributed_data_object(np.ones((8, 8)))
with assert_raises(ValueError):
transformator.create(
x, x.get_codomain(), module=module
).transform(b, axes=(0, 1, 2))
@parameterized.expand(
itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]),
testcase_func_name=custom_name_func
)
def test_local_ndarray(self, module, shape):
x = RGSpace(shape)
a = np.ones(shape)
assert np.allclose(
transformator.create(
x, x.get_codomain(), module=module
).transform(a),
np.fft.fftn(a)
)
@parameterized.expand(
itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]),
testcase_func_name=custom_name_func
)
def test_local_notzero(self, module, shape):
x = RGSpace(shape[0]) # all tests along axis 1
a = np.ones(shape)
b = d2o.distributed_data_object(a)
assert np.allclose(
transformator.create(
x, x.get_codomain(), module=module
).transform(b, axes=(1,)),
np.fft.fftn(a, axes=(1,))
)
@parameterized.expand(
itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]),
testcase_func_name=custom_name_func
)
def test_not(self, module, shape):
x = RGSpace(shape)
a = np.ones(shape)
b = d2o.distributed_data_object(a, distribution_strategy='not')
assert np.allclose(
transformator.create(
x, x.get_codomain(), module=module
).transform(b),
np.fft.fftn(a)
)
# ndarray is not contiguous?
@parameterized.expand(
itertools.product(rg_fft_modules, [(128, 128), (179, 179), (512, 512)]),
testcase_func_name=custom_name_func
)
def test_mpi_axesnone(self, module, shape):
x = RGSpace(shape)
a = np.ones(shape)
b = d2o.distributed_data_object(a)
assert np.allclose(
transformator.create(
x, x.get_codomain(), module=module
).transform(b),
np.fft.fftn(a)
)
#TODO: check what to do when cannot be distributed
if __name__ == '__main__':
unittest.main()
......@@ -4,27 +4,25 @@ import numpy as np
from d2o import distributed_data_object, STRATEGIES
from nifty.config import about, dependency_injector as gdi
import nifty.nifty_utilities as utilities
from transform import FFT
from transform import Transform
from mpi4py import MPI
pyfftw = gdi.get('pyfftw')
class FFTW(FFT):
class FFTW(Transform):
"""
The pyfftw pendant of a fft object.
"""
# The plan_dict stores the FFTWTransformInfo objects which correspond
# to a certain set of (field_val, domain, codomain) sets.
info_dict = {}
# initialize the dictionary which stores the values from
# get_centering_mask
centering_mask_dict = {}
def __init__(self, domain, codomain):
self.domain = domain
self.codomain = codomain
if Transform.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
else:
raise ValueError("ERROR: Invalid codomain!")
if 'pyfftw' not in gdi:
raise ImportError("The module pyfftw is needed but not available.")
......@@ -34,8 +32,15 @@ class FFTW(FFT):
# Enable caching for pyfftw.interfaces
pyfftw.interfaces.cache.enable()
@classmethod
def get_centering_mask(cls, to_center_input, dimensions_input,
# The plan_dict stores the FFTWTransformInfo objects which correspond
# to a certain set of (field_val, domain, codomain) sets.
self.info_dict = {}
# initialize the dictionary which stores the values from
# get_centering_mask
self.centering_mask_dict = {}
def get_centering_mask(self, to_center_input, dimensions_input,
offset_input=False):
"""
Computes the mask, used to (de-)zerocenter domain and target
......@@ -97,7 +102,7 @@ class FFTW(FFT):
# compute an identifier for the parameter set
temp_id = tuple(
(tuple(to_center), tuple(dimensions), tuple(offset)))
if temp_id not in cls.centering_mask_dict:
if temp_id not in self.centering_mask_dict:
# use np.tile in order to stack the core alternation scheme
# until the desired format is constructed.
core = np.fromfunction(
......@@ -135,11 +140,10 @@ class FFTW(FFT):
else:
temp_slice += (slice(None),)
centering_mask = centering_mask[temp_slice]
cls.centering_mask_dict[temp_id] = centering_mask
return cls.centering_mask_dict[temp_id]
self.centering_mask_dict[temp_id] = centering_mask
return self.centering_mask_dict[temp_id]
@classmethod
def _get_transform_info(cls, domain, codomain, local_shape,
def _get_transform_info(self, domain, codomain, local_shape,
local_offset_Q, is_local, transform_shape=None,
**kwargs):
# generate a id-tuple which identifies the domain-codomain setting
......@@ -148,19 +152,19 @@ class FFTW(FFT):
(211 * transform_shape.__hash__()))
# generate the plan_and_info object if not already there
if temp_id not in cls.info_dict:
if temp_id not in self.info_dict:
if is_local:
cls.info_dict[temp_id] = FFTWLocalTransformInfo(
self.info_dict[temp_id] = FFTWLocalTransformInfo(
domain, codomain, local_shape,
local_offset_Q, **kwargs
local_offset_Q, self, **kwargs
)
else:
cls.info_dict[temp_id] = FFTWMPITransfromInfo(
self.info_dict[temp_id] = FFTWMPITransfromInfo(
domain, codomain, local_shape,
local_offset_Q, transform_shape, **kwargs
local_offset_Q, self, transform_shape, **kwargs
)
return cls.info_dict[temp_id]
return self.info_dict[temp_id]
def _apply_mask(self, val, mask, axes):
"""
......@@ -192,6 +196,7 @@ class FFTW(FFT):
return val * mask
def _atomic_mpi_transform(self, val, info, axes):
# Apply codomain centering mask
if reduce(lambda x, y: x+y, self.codomain.paradict['zerocenter']):
temp_val = np.copy(val)
......@@ -297,6 +302,7 @@ class FFTW(FFT):
return_val = val.copy_empty(global_shape=val.shape,
dtype=self.codomain.dtype)
# Extract local data
local_val = val.get_local_data(copy=False)
......@@ -334,7 +340,7 @@ class FFTW(FFT):
local_shape=val.local_shape,
local_offset_Q=local_offset_Q,
is_local=False,
transform_shape=inp.shape,
transform_shape=val.shape, # TODO: check why inp.shape doesn't work
**kwargs
)
......@@ -427,16 +433,16 @@ class FFTW(FFT):
class FFTWTransformInfo(object):
def __init__(self, domain, codomain, local_shape,
local_offset_Q, **kwargs):
local_offset_Q, fftw_context, **kwargs):
if pyfftw is None:
raise ImportError("The module pyfftw is needed but not available.")
self.cmask_domain = FFTW.get_centering_mask(
self.cmask_domain = fftw_context.get_centering_mask(
domain.paradict['zerocenter'],
local_shape,
local_offset_Q)
self.cmask_codomain = FFTW.get_centering_mask(
self.cmask_codomain = fftw_context.get_centering_mask(
codomain.paradict['zerocenter'],
local_shape,
local_offset_Q)
......@@ -475,11 +481,12 @@ class FFTWTransformInfo(object):
class FFTWLocalTransformInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, local_shape,
local_offset_Q, **kwargs):
local_offset_Q, fftw_context, **kwargs):
super(FFTWLocalTransformInfo, self).__init__(domain,
codomain,
local_shape,
local_offset_Q,
fftw_context,
**kwargs)
if codomain.harmonic:
self._fftw_interface = pyfftw.interfaces.numpy_fft.fftn
......@@ -499,11 +506,12 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
class FFTWMPITransfromInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, local_shape,
local_offset_Q, transform_shape, **kwargs):
local_offset_Q, fftw_context, transform_shape, **kwargs):
super(FFTWMPITransfromInfo, self).__init__(domain,
codomain,
local_shape,
local_offset_Q,
fftw_context,
**kwargs)
self._plan = pyfftw.create_mpi_plan(
input_shape=transform_shape,
......
import numpy as np
from transform import FFT
from transform import Transform
from d2o import distributed_data_object
import nifty.nifty_utilities as utilities
class GFFT(FFT):
class GFFT(Transform):
"""
The gfft pendant of a fft object.
......@@ -17,9 +17,12 @@ class GFFT(FFT):
"""
def __init__(self, domain, codomain, fft_module):
self.domain = domain
self.codomain = codomain
self.fft_machine = fft_module
if Transform.check_codomain(domain, codomain):
self.domain = domain
self.codomain = codomain
self.fft_machine = fft_module
else:
raise ValueError("ERROR: Invalid codomain!")
def transform(self, val, axes=None, **kwargs):
"""
......
class FFT(object):
from nifty import RGSpace
from nifty.config import about
import numpy as np
class Transform(object):
"""
A generic fft object without any implementation.
"""
def __init__(self):
@staticmethod
def check_codomain(domain, codomain):
if codomain is None:
return False
if isinstance(domain, RGSpace):
if not isinstance(codomain, RGSpace):
raise TypeError(about._errors.cstring(
"ERROR: The given codomain must be a rg_space."
))
if not np.all(np.array(domain.paradict['shape']) ==
np.array(codomain.paradict['shape'])):
return False
if domain.harmonic == codomain.harmonic:
return False
# Check complexity
# Prepare shorthands
dcomp = domain.paradict['complexity']
cocomp = codomain.paradict['complexity']
# Case 1: if domain is completely complex, the codomain
# must be complex too
if dcomp == 2:
if cocomp != 2:
return False
# Case 2: if domain is hermitian, the codomain can be
# real, a warning is raised otherwise
elif dcomp == 1:
if cocomp > 0:
about.warnings.cprint(
"WARNING: Unrecommended codomain! " +
"The domain is hermitian, hence the" +
"codomain should be restricted to real values."
)
# Case 3: if domain is real, the codomain should be hermitian
elif dcomp == 0:
if cocomp == 2:
about.warnings.cprint(
"WARNING: Unrecommended codomain! " +
"The domain is real, hence the" +
"codomain should be restricted to" +
"hermitian configuration."
)
elif cocomp == 0:
return False
# Check if the distances match, i.e. dist' = 1 / (num * dist)
if not np.all(
np.absolute(np.array(domain.paradict['shape']) *
np.array(domain.distances) *
np.array(codomain.distances) - 1) < domain.epsilon):
return False
else:
return False
return True
def __init__(self, domain, codomain):
pass
def transform(self, val, domain, codomain, axes, **kwargs):
def transform(self, val, axes, **kwargs):
"""
A generic ff-transform function.
......
import numpy as np
from nifty.rg import RGSpace
from nifty.lm import GLSpace, HPSpace, LMSpace
from nifty.config import dependency_injector as gdi
from nifty.config import about, dependency_injector as gdi
from gfft import GFFT
from fftw import FFTW
......@@ -14,33 +16,38 @@ class TransformFactory(object):
# cache for storing the transform objects
self.cache = {}
def _get_transform_override(self, domain, codomain, module):
if module == 'gfft':
return GFFT(domain, codomain, gdi.get('gfft'))
elif module == 'fftw':
return FFTW(domain, codomain)
elif module == 'gfft_dummmy':
return GFFT(domain, codomain, gdi.get('gfft_dummy'))
def _get_transform(self, domain, codomain):
if isinstance(domain, RGSpace) and isinstance(codomain, RGSpace):
def _get_transform(self, domain, codomain, module):
if isinstance(domain, RGSpace):
# fftw -> gfft -> gfft_dummy
if gdi.get('fftw') is None:
if gdi.get('gfft') is None:
if module is None:
if gdi.get('pyfftw') is None:
if gdi.get('gfft') is None:
return GFFT(domain, codomain, gdi.get('gfft_dummy'))
else:
return GFFT(domain, codomain, gdi.get('gfft'))
return FFTW(domain, codomain)
else:
if module == 'pyfftw':
if gdi.get('pyfftw') is not None:
return FFTW(domain, codomain)
else:
raise RuntimeError("ERROR: pyfftw is not available.")
elif module == 'gfft':
if gdi.get('gfft') is not None:
return GFFT(domain, codomain, gdi.get('gfft'))
else:
raise RuntimeError("ERROR: gfft is not available.")
elif module == 'gfft_dummy':
return GFFT(domain, codomain, gdi.get('gfft_dummy'))
else:
return GFFT(domain, codomain, gdi.get('gfft'))
return FFTW(domain, codomain)
raise ValueError('Given FFT module is not known: ' +
str(module))
def create(self, domain, codomain, module=None):
key = domain.__hash__() ^ ((111 * codomain.__hash__()) ^
(179 * module.__hash__()))
if key not in self.cache:
if module is None:
self.cache[key] = self._get_transform(domain, codomain)
else:
self.cache[key] = self._get_transform_override(domain,
codomain,
module)
return self.cache[key]
self.cache[key] = self._get_transform(domain, codomain, module)
return self.cache[key]
\ No newline at end of file