Commit 2ae2d181 authored by Jait Dixit's avatar Jait Dixit
Browse files

Fixes from comments on the merge request

- get_slice_list returns the entire shape when axes is None
- GFFT.transform() is now a single for-loop
- Input array is simply cast to codomain's dtype
- Other formating changes were included/reverted
parent 81ed3ae3
Pipeline #3259 skipped
......@@ -33,22 +33,25 @@ def get_slice_list(shape, axes):
if not shape:
raise ValueError(about._errors.cstring("ERROR: shape cannot be None."))
if not all(axis < len(shape) for axis in axes):
raise ValueError(
about._errors.cstring("ERROR: axes(axis) does not match shape.")
)
axes_select = [0 if x in axes else 1 for x, y in enumerate(shape)]
axes_iterables = [range(y) for x, y in enumerate(shape) if x not in axes]
for index in product(*axes_iterables):
it_iter = iter(index)
slice_list = [
next(it_iter)
if axis else slice(None, None) for axis in axes_select
]
yield slice_list
if axes:
if not all(axis < len(shape) for axis in axes):
raise ValueError(
about._errors.cstring("ERROR: \
axes(axis) does not match shape.")
)
axes_select = [0 if x in axes else 1 for x, y in enumerate(shape)]
axes_iterables =\
[range(y) for x, y in enumerate(shape) if x not in axes]
for index in product(*axes_iterables):
it_iter = iter(index)
slice_list = [
next(it_iter)
if axis else slice(None, None) for axis in axes_select
]
yield slice_list
else:
yield [slice(None, None)]
return
def hermitianize_gaussian(x):
......
......@@ -162,17 +162,14 @@ class FFTW(FFT):
# use np.tile in order to stack the core alternation scheme
# until the desired format is constructed.
core = np.fromfunction(
lambda *args: (-1) ** (
np.tensordot(
to_center,
args + offset.reshape(
offset.shape + (1,) * (np.array(args).ndim - 1)
),
1
)
),
(2,) * to_center.size
)
lambda *args: (-1) **
(np.tensordot(to_center,
args +
offset.reshape(offset.shape +
(1,) *
(np.array(args).ndim - 1)),
1)),
(2,) * to_center.size)
# Cast the core to the smallest integers we can get
core = core.astype(np.int8)
......@@ -436,55 +433,42 @@ class GFFT(FFT):
result : np.ndarray
Fourier-transformed pendant of the input field.
"""
# GFFT is dumb. The entire input needs to be present on the node.
# GFFT doesn't accept d2o objects as input. Consolidate data from
# all nodes into numpy.ndarray before proceeding.
if isinstance(val, distributed_data_object):
temp = val.get_full_data()
else:
temp = val
# Cast input datatype to complex
if domain.dtype == np.float64:
temp = temp.astype(np.complex128)
# Cast input datatype to codomain's dtype
temp = temp.astype(codomain.dtype)
# Result is generated and stored in a local numpy array
# Array for storing the result
return_val = np.empty_like(temp)
if axes:
for slice_list in utilities.get_slice_list(temp.shape, axes):
for slice_list in utilities.get_slice_list(temp.shape, axes):
# don't copy the whole data array
if slice_list == [slice(None, None)]:
inp = temp
else:
inp = temp[slice_list]
inp = self.fft_machine.gfft(
inp,
in_ax=[],
out_ax=[],
ftmachine='fft' if codomain.harmonic else 'ifft',
in_zero_center=map(bool, domain.paradict['zerocenter']),
out_zero_center=map(bool, codomain.paradict['zerocenter']),
enforce_hermitian_symmetry=bool(
codomain.paradict['complexity']
),
W=-1,
alpha=-1,
verbose=False
)
return_val[slice_list] = inp
else:
return_val = self.fft_machine.gfft(
temp,
inp = self.fft_machine.gfft(
inp,
in_ax=[],
out_ax=[],
ftmachine='fft' if codomain.harmonic else 'ifft',
in_zero_center=map(bool, domain.paradict['zerocenter']),
out_zero_center=map(bool, codomain.paradict['zerocenter']),
enforce_hermitian_symmetry=bool(
codomain.paradict['complexity']
),
enforce_hermitian_symmetry=
bool(codomain.paradict['complexity']),
W=-1,
alpha=-1,
verbose=False
)
return_val[slice_list] = inp
if isinstance(val, distributed_data_object):
new_val = val.copy_empty(dtype=codomain.dtype)
new_val.set_full_data(return_val)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment