Skip to content
Snippets Groups Projects
Commit 38aa1c4b authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: Fix FFTW caching

- Refactor RGRGTransformation
- Add is_local to dictionary key hash
parent 370fdd22
No related branches found
No related tags found
1 merge request!35Fix fftw transform
......@@ -170,7 +170,9 @@ class FFTW(Transform):
# generate a id-tuple which identifies the domain-codomain setting
temp_id = (domain.__hash__() ^
(101 * codomain.__hash__()) ^
(211 * transform_shape.__hash__()))
(211 * transform_shape.__hash__()) ^
(131 * is_local.__hash__())
)
# generate the plan_and_info object if not already there
if temp_id not in self.info_dict:
......
......@@ -7,29 +7,29 @@ from nifty import RGSpace, nifty_configuration
class RGRGTransformation(Transformation):
def __init__(self, domain, codomain=None, module=None):
super(RGRGTransformation, self).__init__(domain, codomain, module)
super(RGRGTransformation, self).__init__(domain, codomain)
if module is None:
if nifty_configuration['fft_module'] == 'pyfftw':
self._transform = FFTW(domain, codomain)
self._transform = FFTW(self.domain, self.codomain)
elif (nifty_configuration['fft_module'] == 'gfft' or
nifty_configuration['fft_module'] == 'gfft_dummy'):
self._transform = \
GFFT(domain,
codomain,
GFFT(self.domain,
self.codomain,
gdi.get(nifty_configuration['fft_module']))
else:
raise ValueError('ERROR: unknow default FFT module:' +
nifty_configuration['fft_module'])
else:
if module == 'pyfftw':
self._transform = FFTW(domain, codomain)
self._transform = FFTW(self.domain, self.codomain)
elif module == 'gfft':
self._transform = \
GFFT(domain, codomain, gdi.get('gfft'))
GFFT(self.domain, self.codomain, gdi.get('gfft'))
elif module == 'gfft_dummy':
self._transform = \
GFFT(domain, codomain, gdi.get('gfft_dummy'))
GFFT(self.domain, self.codomain, gdi.get('gfft_dummy'))
else:
raise ValueError('ERROR: unknow FFT module:' + module)
......
......@@ -11,7 +11,7 @@ class Transformation(object, Loggable):
"""
__metaclass__ = abc.ABCMeta
def __init__(self, domain, codomain=None, module=None):
def __init__(self, domain, codomain):
if codomain is None:
self.domain = domain
self.codomain = self.get_codomain(domain)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment