Commit 588237ae authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP FFTW Transform

- Refactor based on discussion on 08.06
- Add test for FFTW transform
parent 27bcdb96
Pipeline #4797 skipped
# -*- coding: utf-8 -*-
import numpy as np
from mpi4py import MPI
from d2o import distributed_data_object, STRATEGIES
from nifty.config import dependency_injector as gdi
from nifty.config import about
import nifty.nifty_utilities as utilities
pyfftw = gdi.get('pyfftw')
......@@ -207,11 +209,20 @@ class FFTW(FFT):
def _get_transform_info(self, domain, codomain, is_local=False, **kwargs):
# generate a id-tuple which identifies the domain-codomain setting
temp_id = domain.__hash__() ^ (101 * codomain.__hash__())
# generate the plan_and_info object if not already there
if temp_id not in self.info_dict:
self.info_dict[temp_id] = FFTWTransformInfo(domain, codomain,
is_local, self,
**kwargs)
if is_local:
self.info_dict[temp_id] = FFTWLocalTransformInfo(
domain, codomain,
self, **kwargs
)
else:
self.info_dict[temp_id] = FFTWMPITransfromInfo(
domain, codomain,
self, **kwargs
)
return self.info_dict[temp_id]
def _apply_mask(self, val, mask, axes):
......@@ -249,20 +260,35 @@ class FFTW(FFT):
temp_val = np.copy(val)
val = self._apply_mask(temp_val, info.cmask_codomain, axes)
if info.local:
result = info.plan(val, axes=axes, planner_effort='FFTW_ESTIMATE')
result = info.fftw_interface(val, axes=axes,
planner_effort='FFTW_ESTIMATE')
# Apply domain centering mask
if reduce(lambda x, y: x+y, domain.paradict['zerocenter']):
result = self._apply_mask(result, info.cmask_domain, axes)
# Correct the sign if needed
result *= info.sign
return result
def _mpi_transform(self, val, info, axes, domain, codomain):
# Apply codomain centering mask
if reduce(lambda x, y: x+y, codomain.paradict['zerocenter']):
temp_val = np.copy(val)
val = self._apply_mask(temp_val, info.cmask_codomain, axes)
p = info.plan
# Load the value into the plan
if p.has_input:
p.input_array[:] = val
# Execute the plan
p()
if p.has_output:
result = p.output_array
else:
p = info.plan
# Load the value into the plan
if p.has_input:
p.input_array[:] = val
# Execute the plan
p()
if p.has_output:
result = p.output_array
else:
raise RuntimeError('ERROR: PyFFTW-MPI transform failed.')
raise RuntimeError('ERROR: PyFFTW-MPI transform failed.')
# Apply domain centering mask
if reduce(lambda x, y: x+y, domain.paradict['zerocenter']):
......@@ -307,66 +333,80 @@ class FFTW(FFT):
# If the input is a numpy array we transform it locally
if not isinstance(val, distributed_data_object):
# Cast to a np.ndarray
temp_val = np.asarray(val)
current_info = self._get_transform_info(domain, codomain,
is_local=True,
**kwargs)
# local transform doesn't apply transforms inplace
return_val = self._local_transform(val, current_info, axes,
return_val = self._local_transform(temp_val, current_info, axes,
domain, codomain)
else:
if val.distribution_strategy in STRATEGIES['slicing']:
if axes is None or set(axes) == set(range(len(val.shape))):
current_info = self._get_transform_info(domain, codomain,
**kwargs)
if val.comm is not MPI.COMM_WORLD:
raise RuntimeError('ERROR: Input array uses an unsupported \
comm object')
return_val = val.copy_empty(global_shape=val.shape,
dtype=codomain.dtype)
# Compute transform for the local data
result = self._local_transform(
val.get_local_data(copy=False),
current_info, axes,
domain, codomain
)
# Insert result into the return_val
return_val.set_local_data(data=result, copy=False)
elif 0 in axes:
current_info = self._get_transform_info(domain, codomain,
**kwargs)
# Repack val to fftw
new_val = val.copy(distribution_strategy='fftw',
copy=False)
return_val = new_val.copy_empty(global_shape=new_val.shape,
if val.distribution_strategy in STRATEGIES['slicing']:
if axes is None or set(axes) == set(range(len(val.shape))) \
or 0 in axes:
if val.distribution_strategy != 'fftw':
temp_val = val.copy_empty(distribution_strategy='fftw')
temp_val.set_full_data(val, copy=False)
# Recursive call to transform
result = self.transform(temp_val, domain, codomain,
axes, **kwargs)
return_val = result.copy_empty(
distribution_strategy=val.distribution_strategy
)
return_val.set_full_data(data=result, copy=False)
else:
current_info = self._get_transform_info(domain,
codomain,
**kwargs)
return_val = val.copy_empty(global_shape=val.shape,
dtype=codomain.dtype)
# Extract local data
local_val = val.get_local_data(copy=False)
# Intermediate storage for individual slices
temp_val = np.empty_like(local_val)
for slice_list in utilities.get_slice_list(local_val.shape,
axes):
inp = local_val[slice_list]
result = self._local_transform(inp, current_info,
axes, domain, codomain)
temp_val[slice_list] = result
return_val.set_local_data(data=temp_val, copy=False)
# Repack to original distribution strategy
return_val = return_val.copy(
distribution_strategy=val.distribution_strategy,
copy=False
)
# Extract local data
local_val = val.get_local_data(copy=False)
# Create temporary storage for slices
temp_val = None
# If axes tuple includes all axes, set it to None
if set(axes) == set(range(len(val.shape))):
axes = None
for slice_list in\
utilities.get_slice_list(local_val.shape,
axes):
if slice_list == [slice(None, None)]:
inp = local_val
else:
if temp_val is None:
temp_val = np.empty_like(local_val)
inp = local_val[slice_list]
result = self._mpi_transform(inp,
current_info,
axes, domain,
codomain)
if slice_list == [slice(None, None)]:
temp_val = result
else:
temp_val[slice_list] = result
return_val.set_local_data(data=temp_val, copy=False)
else:
current_info = self._get_transform_info(domain, codomain,
is_local=True,
**kwargs)
return_val = val.copy_empty(global_shape=val.shape,
dtype=codomain.dtype)
# Compute transform for the local data
result = self._local_transform(
val.get_local_data(copy=False),
......@@ -374,42 +414,43 @@ class FFTW(FFT):
domain, codomain
)
# Insert result inplace in return_val
# Create return object and insert results inplace
return_val = val.copy_empty(global_shape=val.shape,
dtype=codomain.dtype)
return_val.set_local_data(data=result, copy=False)
# If domain is purely real, the result of the FFT is hermitian
if domain.paradict['complexity'] == 0:
return_val.hermitian = True
else:
new_val = val.copy(distribution_strategy='fftw',
copy=False)
temp_val = val.copy_empty(distribution_strategy='fftw')
about.warnings.cprint('WARNING: Repacking d2o to fftw \
distribution strategy')
temp_val.set_full_data(val, copy=False)
# Recursive call to take advantage of the fact that the data
# necessary is already present on the nodes.
return_val = self.transform(new_val, domain, codomain, axes,
**kwargs)
result = self.transform(temp_val, domain, codomain, axes,
**kwargs)
return_val = return_val.copy(
distribution_strategy=val.distribution_strategy,
copy=False
return_val = val.copy_empty(
distribution_strategy=val.distribution_strategy
)
return_val.set_full_data(result, copy=False)
# If domain is purely real, the result of the FFT is hermitian
if domain.paradict['complexity'] == 0:
return_val.hermitian = True
return return_val
return return_val
class FFTWTransformInfo(object):
def __init__(self, domain, codomain, is_local, fftw_context, **kwargs):
def __init__(self, domain, codomain, fftw_context, **kwargs):
if pyfftw is None:
raise ImportError("The module pyfftw is needed but not available.")
# Whether it's a local or an MPI transformation
self.local = is_local
# Create domain centering mask. The mask is the same size
# as the domain's shape.
self.cmask_domain = fftw_context.get_centering_mask(
......@@ -432,24 +473,6 @@ class FFTWTransformInfo(object):
(np.array(domain.get_shape()) // 2 % 2)
)
# Create the plans. If it's a local transform store pyfftw.interfaces
# instance otherwise an MPI plan
if self.local:
if codomain.harmonic:
self.plan = pyfftw.interfaces.numpy_fft.fftn
else:
self.plan = pyfftw.interfaces.numpy_fftn.ifftn
else:
self.plan = pyfftw.create_mpi_plan(
input_shape=domain.get_shape(),
input_dtype='complex128',
output_dtype='complex128',
direction='FFTW_FORWARD'
if codomain.harmonic else 'FFTW_BACKWARD',
flags=["FFTW_ESTIMATE"],
**kwargs
)
@property
def cmask_domain(self):
return self._domain_centering_mask
......@@ -467,12 +490,45 @@ class FFTWTransformInfo(object):
self._codomain_centering_mask = cmask
@property
def local(self):
return self._local
def sign(self):
return self._sign
@sign.setter
def sign(self, sign):
self._sign = sign
@local.setter
def local(self, local):
self._local = local
class FFTWLocalTransformInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, fftw_context, **kwargs):
super(FFTWLocalTransformInfo, self).__init__(domain, codomain,
fftw_context, **kwargs)
if codomain.harmonic:
self._fftw_interface = pyfftw.interfaces.numpy_fft.fftn
else:
self._fftw_interface = pyfftw.interfaces.numpy_fftn.ifftn
@property
def fftw_interface(self):
return self._fftw_interface
@fftw_interface.setter
def fftw_interface(self, interface):
about.warnings.cprint('WARNING: FFTWLocalTransformInfo fftw_interface \
cannot be modified')
class FFTWMPITransfromInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, fftw_context, **kwargs):
super(FFTWMPITransfromInfo, self).__init__(domain, codomain,
fftw_context, **kwargs)
self._plan = pyfftw.create_mpi_plan(
input_shape=domain.get_shape(),
input_dtype='complex128',
output_dtype='complex128',
direction='FFTW_FORWARD' if codomain.harmonic else 'FFTW_BACKWARD',
flags=["FFTW_ESTIMATE"],
**kwargs
)
@property
def plan(self):
......@@ -480,15 +536,8 @@ class FFTWTransformInfo(object):
@plan.setter
def plan(self, plan):
self._plan = plan
@property
def sign(self):
return self._sign
@sign.setter
def sign(self, sign):
self._sign = sign
about.warnings.cprint('WARNING: FFTWMPITransfromInfo plan \
cannot be modified')
class GFFT(FFT):
......
import nifty as nt
import numpy as np
import unittest
import d2o
class TestFFTWTransform(unittest.TestCase):
def test_comm(self):
x = nt.rg_space((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = d2o.distributed_data_object(a)
b.comm = [1, 2, 3] # change comm to something not supported
with self.assertRaises(RuntimeError):
x.fft_machine.transform(b, x, x.get_codomain())
def test_shapemismatch(self):
x = nt.rg_space((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = x.cast(a)
with self.assertRaises(ValueError):
x.fft_machine.transform(b, x, x.get_codomain(), axes=(0, 1, 2))
def test_local_ndarray(self):
x = nt.rg_space((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
self.assertTrue(
np.allclose(
x.fft_machine.transform(a, x, x.get_codomain()),
np.fft.fftn(a)
), 'results do not match numpy.fft.fftn'
)
def test_local_notzero(self):
x = nt.rg_space(8, fft_module='pyfftw')
a = np.ones((8, 8))
b = x.cast(a)
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain(), axes=(1,)),
np.fft.fftn(a, axes=(1,))
), 'results do not match numpy.fft.fftn'
)
def test_not(self):
x = nt.rg_space((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = d2o.distributed_data_object(a, distribution_strategy='not')
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain()),
np.fft.fftn(a)
), 'results do not match numpy.fft.fftn'
)
def test_mpi_axesnone(self):
x = nt.rg_space((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = x.cast(a)
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain()),
np.fft.fftn(a)
), 'results do not match numpy.fft.fftn'
)
def test_mpi_axesnone_equal(self):
x = nt.rg_space((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = d2o.distributed_data_object(a, distribution_strategy='equal')
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain()),
np.fft.fftn(a)
), 'results do not match numpy.fft.fftn'
)
def test_mpi_axesall(self):
x = nt.rg_space((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = x.cast(a)
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain(), axes=(0, 1)),
np.fft.fftn(a)
), 'results do not match numpy.fft.fftn'
)
def test_mpi_axesall_equal(self):
x = nt.rg_space((8, 8), fft_module='pyfftw')
a = np.ones((8, 8))
b = d2o.distributed_data_object(a, distribution_strategy='equal')
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain(), axes=(0, 1)),
np.fft.fftn(a)
), 'results do not match numpy.fft.fftn'
)
def test_mpi_zero(self):
x = nt.rg_space(8, fft_module='pyfftw')
a = np.ones((8, 8)) + 1j*np.zeros((8, 8))
b = x.cast(a)
self.assertTrue(
np.allclose(
x.fft_machine.transform(b, x, x.get_codomain(), axes=(0,)),
np.fft.fftn(a, axes=(0,))
), 'results do not match numpy.fft.fftn'
)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment