Commit dad0df03 authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: More housekeeping

- Rename FFT classes so that they follow PEP8 naming conventions. The following classes have been renamed:
    - fft -> FFT
    - fft_fftw -> FFTW
    - fft_gfft -> GFFT
- Add docstrings for new arguments and for the old ones missing one
- Add axes keyword to rg_space.calc_transform()
- More PEP8 fixes in nifty_fft
- GFFT's transform method now uses paradict instead of para
parent 762f85d7
Pipeline #2724 skipped
...@@ -32,30 +32,27 @@ def fft_factory(fft_module_name): ...@@ -32,30 +32,27 @@ def fft_factory(fft_module_name):
Parameters Parameters
---------- ----------
None fft_module_name : String
Select an FFT module
Returns Returns
----- -----
fft: Returns a fft_object depending on the available packages. fft : Returns a fft_object depending on the available packages.
Hierarchy: pyfftw -> gfft -> built in gfft. Hierarchy: pyfftw -> gfft -> built in gfft.
""" """
if fft_module_name == 'pyfftw': if fft_module_name == 'pyfftw':
return fft_fftw() return FFTW()
elif fft_module_name == 'gfft' or 'gfft_dummy': elif fft_module_name == 'gfft' or 'gfft_dummy':
return fft_gfft(fft_module_name) return GFFT(fft_module_name)
else: else:
raise ValueError('Given fft_module_name not known: ' + raise ValueError('Given fft_module_name not known: ' +
str(fft_module_name)) str(fft_module_name))
class fft(object): class FFT(object):
""" """
A generic fft object without any implementation. A generic fft object without any implementation.
Parameters
----------
None
""" """
def __init__(self): def __init__(self):
...@@ -80,14 +77,9 @@ class fft(object): ...@@ -80,14 +77,9 @@ class fft(object):
return None return None
class fft_fftw(fft): class FFTW(FFT):
""" """
The pyfftw pendant of a fft object. The pyfftw pendant of a fft object.
Parameters
----------
None
""" """
def __init__(self): def __init__(self):
...@@ -168,14 +160,17 @@ class fft_fftw(fft): ...@@ -168,14 +160,17 @@ class fft_fftw(fft):
# use np.tile in order to stack the core alternation scheme # use np.tile in order to stack the core alternation scheme
# until the desired format is constructed. # until the desired format is constructed.
core = np.fromfunction( core = np.fromfunction(
lambda *args: (-1) ** lambda *args: (-1) ** (
(np.tensordot(to_center, np.tensordot(
args + to_center,
offset.reshape(offset.shape + args + offset.reshape(
(1,) * offset.shape + (1,) * (np.array(args).ndim - 1)
(np.array(args).ndim - 1)), ),
1)), 1
(2,) * to_center.size) )
),
(2,) * to_center.size
)
# Cast the core to the smallest integers we can get # Cast the core to the smallest integers we can get
core = core.astype(np.int8) core = core.astype(np.int8)
...@@ -197,7 +192,7 @@ class fft_fftw(fft): ...@@ -197,7 +192,7 @@ class fft_fftw(fft):
# dimension was one # dimension was one
temp_slice = () temp_slice = ()
for i in range(len(size_one_dimensions)): for i in range(len(size_one_dimensions)):
if size_one_dimensions[i] == True: if size_one_dimensions[i]:
temp_slice += (None,) temp_slice += (None,)
else: else:
temp_slice += (slice(None),) temp_slice += (slice(None),)
...@@ -394,13 +389,14 @@ class _fftw_plan_and_info(object): ...@@ -394,13 +389,14 @@ class _fftw_plan_and_info(object):
) )
class fft_gfft(fft): class GFFT(FFT):
""" """
The gfft pendant of a fft object. The gfft pendant of a fft object.
Parameters Parameters
---------- ----------
None fft_module_name : String
Switch between the gfft module used: 'gfft' and 'gfft_dummy'
""" """
...@@ -426,6 +422,9 @@ class fft_gfft(fft): ...@@ -426,6 +422,9 @@ class fft_gfft(fft):
codomain : nifty.rg.nifty_rg.rg_space codomain : nifty.rg.nifty_rg.rg_space
The target into which the field should be transformed. The target into which the field should be transformed.
axes : None or tuple
The axes which should be transformed.
**kwargs : *optional* **kwargs : *optional*
Further kwargs are not processed. Further kwargs are not processed.
...@@ -434,7 +433,6 @@ class fft_gfft(fft): ...@@ -434,7 +433,6 @@ class fft_gfft(fft):
result : np.ndarray result : np.ndarray
Fourier-transformed pendant of the input field. Fourier-transformed pendant of the input field.
""" """
naxes = len(domain.get_shape())
if codomain.harmonic: if codomain.harmonic:
ftmachine = "fft" ftmachine = "fft"
else: else:
...@@ -453,30 +451,30 @@ class fft_gfft(fft): ...@@ -453,30 +451,30 @@ class fft_gfft(fft):
in_ax=[], in_ax=[],
out_ax=[], out_ax=[],
ftmachine=ftmachine, ftmachine=ftmachine,
in_zero_center=domain.para[-naxes:]. in_zero_center=map(bool, domain.paradict['zerocenter']),
astype(np.bool).tolist(), out_zero_center=map(bool, codomain.paradict['zerocenter']),
out_zero_center=codomain.para[-naxes:].
astype(np.bool).tolist(),
enforce_hermitian_symmetry=bool( enforce_hermitian_symmetry=bool(
codomain.para[naxes] == 1), codomain.paradict['complexity']
),
W=-1, W=-1,
alpha=-1, alpha=-1,
verbose=False) verbose=False
)
else: else:
temp = self.fft_machine.gfft( temp = self.fft_machine.gfft(
temp, temp,
in_ax=[], in_ax=[],
out_ax=[], out_ax=[],
ftmachine=ftmachine, ftmachine=ftmachine,
in_zero_center=domain.para[-naxes:]. in_zero_center=map(bool, domain.paradict['zerocenter']),
astype(np.bool).tolist(), out_zero_center=map(bool, codomain.paradict['zerocenter']),
out_zero_center=codomain.para[-naxes:].
astype(np.bool).tolist(),
enforce_hermitian_symmetry=bool( enforce_hermitian_symmetry=bool(
codomain.para[naxes] == 1), codomain.paradict['complexity']
),
W=-1, W=-1,
alpha=-1, alpha=-1,
verbose=False) verbose=False
)
if d2oQ: if d2oQ:
new_val = val.copy_empty(dtype=np.complex128) new_val = val.copy_empty(dtype=np.complex128)
new_val.set_full_data(temp) new_val.set_full_data(temp)
......
...@@ -828,7 +828,7 @@ class rg_space(point_space): ...@@ -828,7 +828,7 @@ class rg_space(point_space):
result = np.asscalar(np.real(result)) result = np.asscalar(np.real(result))
return result return result
def calc_transform(self, x, codomain=None, **kwargs): def calc_transform(self, x, codomain=None, axes=None, **kwargs):
""" """
Computes the transform of a given array of field values. Computes the transform of a given array of field values.
...@@ -839,6 +839,8 @@ class rg_space(point_space): ...@@ -839,6 +839,8 @@ class rg_space(point_space):
codomain : nifty.rg_space, *optional* codomain : nifty.rg_space, *optional*
codomain space to which the transformation shall map codomain space to which the transformation shall map
(default: None). (default: None).
axes : None or tuple
Axes in the array which should be transformed.
Returns Returns
------- -------
...@@ -861,7 +863,7 @@ class rg_space(point_space): ...@@ -861,7 +863,7 @@ class rg_space(point_space):
# Perform the transformation # Perform the transformation
Tx = self.fft_machine.transform(val=x, domain=self, codomain=codomain, Tx = self.fft_machine.transform(val=x, domain=self, codomain=codomain,
**kwargs) axes=axes, **kwargs)
if not codomain.harmonic: if not codomain.harmonic:
# correct for inverse fft # correct for inverse fft
...@@ -1675,4 +1677,4 @@ class rg_space(point_space): ...@@ -1675,4 +1677,4 @@ class rg_space(point_space):
def __repr__(self): def __repr__(self):
string = super(rg_space, self).__repr__() string = super(rg_space, self).__repr__()
string += repr(self.fft_machine) + "\n " string += repr(self.fft_machine) + "\n "
return string return string
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment