Commit dad0df03 authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: More housekeeping

- Rename FFT classes so that they follow PEP8 naming conventions. The following classes have been renamed:
    - fft -> FFT
    - fft_fftw -> FFTW
    - fft_gfft -> GFFT
- Add docstrings for new arguments and for the old ones missing one
- Add axes keyword to rg_space.calc_transform()
- More PEP8 fixes in nifty_fft
- GFFT's transform method now uses paradict instead of para
parent 762f85d7
Pipeline #2724 skipped
......@@ -32,30 +32,27 @@ def fft_factory(fft_module_name):
Parameters
----------
None
fft_module_name : String
Select an FFT module
Returns
-----
fft: Returns a fft_object depending on the available packages.
fft : Returns a fft_object depending on the available packages.
Hierarchy: pyfftw -> gfft -> built in gfft.
"""
if fft_module_name == 'pyfftw':
return fft_fftw()
return FFTW()
elif fft_module_name == 'gfft' or 'gfft_dummy':
return fft_gfft(fft_module_name)
return GFFT(fft_module_name)
else:
raise ValueError('Given fft_module_name not known: ' +
str(fft_module_name))
class fft(object):
class FFT(object):
"""
A generic fft object without any implementation.
Parameters
----------
None
"""
def __init__(self):
......@@ -80,14 +77,9 @@ class fft(object):
return None
class fft_fftw(fft):
class FFTW(FFT):
"""
The pyfftw pendant of a fft object.
Parameters
----------
None
"""
def __init__(self):
......@@ -168,14 +160,17 @@ class fft_fftw(fft):
# use np.tile in order to stack the core alternation scheme
# until the desired format is constructed.
core = np.fromfunction(
lambda *args: (-1) **
(np.tensordot(to_center,
args +
offset.reshape(offset.shape +
(1,) *
(np.array(args).ndim - 1)),
1)),
(2,) * to_center.size)
lambda *args: (-1) ** (
np.tensordot(
to_center,
args + offset.reshape(
offset.shape + (1,) * (np.array(args).ndim - 1)
),
1
)
),
(2,) * to_center.size
)
# Cast the core to the smallest integers we can get
core = core.astype(np.int8)
......@@ -197,7 +192,7 @@ class fft_fftw(fft):
# dimension was one
temp_slice = ()
for i in range(len(size_one_dimensions)):
if size_one_dimensions[i] == True:
if size_one_dimensions[i]:
temp_slice += (None,)
else:
temp_slice += (slice(None),)
......@@ -394,13 +389,14 @@ class _fftw_plan_and_info(object):
)
class fft_gfft(fft):
class GFFT(FFT):
"""
The gfft pendant of a fft object.
Parameters
----------
None
fft_module_name : String
Switch between the gfft module used: 'gfft' and 'gfft_dummy'
"""
......@@ -426,6 +422,9 @@ class fft_gfft(fft):
codomain : nifty.rg.nifty_rg.rg_space
The target into which the field should be transformed.
axes : None or tuple
The axes which should be transformed.
**kwargs : *optional*
Further kwargs are not processed.
......@@ -434,7 +433,6 @@ class fft_gfft(fft):
result : np.ndarray
Fourier-transformed pendant of the input field.
"""
naxes = len(domain.get_shape())
if codomain.harmonic:
ftmachine = "fft"
else:
......@@ -453,30 +451,30 @@ class fft_gfft(fft):
in_ax=[],
out_ax=[],
ftmachine=ftmachine,
in_zero_center=domain.para[-naxes:].
astype(np.bool).tolist(),
out_zero_center=codomain.para[-naxes:].
astype(np.bool).tolist(),
in_zero_center=map(bool, domain.paradict['zerocenter']),
out_zero_center=map(bool, codomain.paradict['zerocenter']),
enforce_hermitian_symmetry=bool(
codomain.para[naxes] == 1),
codomain.paradict['complexity']
),
W=-1,
alpha=-1,
verbose=False)
verbose=False
)
else:
temp = self.fft_machine.gfft(
temp,
in_ax=[],
out_ax=[],
ftmachine=ftmachine,
in_zero_center=domain.para[-naxes:].
astype(np.bool).tolist(),
out_zero_center=codomain.para[-naxes:].
astype(np.bool).tolist(),
in_zero_center=map(bool, domain.paradict['zerocenter']),
out_zero_center=map(bool, codomain.paradict['zerocenter']),
enforce_hermitian_symmetry=bool(
codomain.para[naxes] == 1),
codomain.paradict['complexity']
),
W=-1,
alpha=-1,
verbose=False)
verbose=False
)
if d2oQ:
new_val = val.copy_empty(dtype=np.complex128)
new_val.set_full_data(temp)
......
......@@ -828,7 +828,7 @@ class rg_space(point_space):
result = np.asscalar(np.real(result))
return result
def calc_transform(self, x, codomain=None, **kwargs):
def calc_transform(self, x, codomain=None, axes=None, **kwargs):
"""
Computes the transform of a given array of field values.
......@@ -839,6 +839,8 @@ class rg_space(point_space):
codomain : nifty.rg_space, *optional*
codomain space to which the transformation shall map
(default: None).
axes : None or tuple
Axes in the array which should be transformed.
Returns
-------
......@@ -861,7 +863,7 @@ class rg_space(point_space):
# Perform the transformation
Tx = self.fft_machine.transform(val=x, domain=self, codomain=codomain,
**kwargs)
axes=axes, **kwargs)
if not codomain.harmonic:
# correct for inverse fft
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment