Commit 5a7e9f94 authored by Jait Dixit's avatar Jait Dixit
Browse files

Use reshaped centering masks instead of iterating over all slices

parent 614fcf1a
Pipeline #3755 skipped
......@@ -96,6 +96,9 @@ class FFTW(FFT):
# get_centering_mask
self.centering_mask_dict = {}
# Enable caching for pyfftw.interfaces
pyfftw.interfaces.cache.enable()
def get_centering_mask(self, to_center_input, dimensions_input,
offset_input=0):
"""
......@@ -251,17 +254,24 @@ class FFTW(FFT):
codomain.get_shape()
)
# reshape the masks
if axes:
domain_centering_mask = domain_centering_mask.reshape(
[y if x in axes else 1
for x, y in enumerate(local_val.shape)]
)
codomain_centering_mask = codomain_centering_mask.reshape(
[y if x in axes else 1
for x, y in enumerate(local_val.shape)]
)
# Apply codomain centering mask
for slice_list in utilities.get_slice_list(local_val.shape, axes):
if slice_list == [slice(None, None)]:
local_val *= codomain_centering_mask
else:
local_val[slice_list] *= codomain_centering_mask
local_val *= codomain_centering_mask
# We use pyfftw.interface.numpy_fft module to handle transformation
# in this case. The first call, with the given input combination,
# might be slow but the subsequent calls in the same session will
# be much faster.
# be much faster as caching is enabled.
if codomain.harmonic:
return_val = pyfftw.interfaces.numpy_fft.fftn(local_val,
axes=axes)
......@@ -270,11 +280,7 @@ class FFTW(FFT):
axes=axes)
# Apply domain centering mask
for slice_list in utilities.get_slice_list(local_val.shape, axes):
if slice_list == [slice(None, None)]:
return_val *= domain_centering_mask
else:
return_val[slice_list] *= domain_centering_mask
return_val *= domain_centering_mask
return return_val.astype(codomain.dtype)
else:
......@@ -316,10 +322,17 @@ class FFTW(FFT):
codomain.get_shape()
)
domain_centering_mask = domain_centering_mask.reshape(
[y if x in axes else 1
for x, y in enumerate(local_val.shape)]
)
codomain_centering_mask = codomain_centering_mask.reshape(
[y if x in axes else 1
for x, y in enumerate(local_val.shape)]
)
# Apply codomain centering mask
for slice_list in utilities.get_slice_list(local_val.shape,
axes):
local_val[slice_list] *= codomain_centering_mask
local_val *= codomain_centering_mask
if codomain.harmonic:
result = pyfftw.interfaces.numpy_fft.fftn(
......@@ -333,9 +346,7 @@ class FFTW(FFT):
)
# Apply domain centering mask
for slice_list in utilities.get_slice_list(local_val.shape,
axes):
result[slice_list] *= domain_centering_mask
local_val *= domain_centering_mask
# Push data in-place in the array to be returned
if return_val.distribution_strategy == 'fftw':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment