Commit ce38c34c authored by ultimanet's avatar ultimanet
Browse files

Added documentation

parent ada34dff
......@@ -16,20 +16,22 @@ except(ImportError):
import gfft_rg as gfft
fft_machine='gfft_fallback'
#about.infos.cprint('INFO: Using builtin "plain" gfft version 0.1.0')
'''
The fft_factory checks which fft module is available and returns the appropriate fft object.
The fft objects must get 3 parameters:
1. field_val:
The value-array of the field which is supposed to be transformed
2. rg_space:
The field's underlying rg_space
3. codaim
The rg_space into which the field is transformed
'''
def fft_factory():
def fft_factory():
"""
A factory for fast-fourier-transformation objects.
Parameters
----------
None
Returns
-----
fft: Returns a fft_object depending on the available packages.
Hierarchy: pyfftw -> gfft -> built in gfft.
"""
if fft_machine == 'pyfftw':
return fft_fftw()
......@@ -39,15 +41,37 @@ def fft_factory():
class fft(object):
"""
A generic fft object without any implementation.
Parameters
----------
None
"""
def transform(self,field_val,domain,codomain,**kwargs):
"""
A generic ff-transform function.
Parameters
----------
field_val : numpy.ndarray
The value-array of the field which is supposed to
be transformed.
domain : nifty.rg.nifty_rg.rg_space
The domain of the space which should be transformed.
codomain : nifty.rg.nifty_rg.rg_space
The taget into which the field should be transformed.
"""
return None
if fft_machine == 'pyfftw':
## The instances of plan_and_info store the fftw plan and all
## other information needed in order to perform a mpi-fftw transformation
class _fftw_plan_and_info(object):
def __init__(self,domain,codomain,fft_fftw_context):
self.compute_plan_and_info(domain,codomain,fft_fftw_context)
def __init__(self,domain,codomain,fft_fftw_context,**kwargs):
self.compute_plan_and_info(domain,codomain,fft_fftw_context,**kwargs)
def set_plan(self, x):
self.plan=x
......@@ -64,7 +88,7 @@ if fft_machine == 'pyfftw':
def get_codomain_centering_mask(self):
return self.codomain_centering_mask
def compute_plan_and_info(self, domain, codomain,fft_fftw_context):
def compute_plan_and_info(self, domain, codomain,fft_fftw_context,**kwargs):
self.input_dtype = 'complex128'
self.output_dtype = 'complex128'
......@@ -103,13 +127,44 @@ if fft_machine == 'pyfftw':
input_dtype=self.input_dtype,
output_dtype=self.output_dtype,
direction=self.direction,
flags=["FFTW_ESTIMATE"])
flags=["FFTW_ESTIMATE"],
**kwargs)
)
class fft_fftw(fft):
class fft_fftw(fft):
"""
The pyfftw pendant of a fft object.
Parameters
----------
None
"""
## initialize the dictionary which stores the values from get_centering_mask
centering_mask_dict = {}
def get_centering_mask(self, to_center_input, dimensions_input, offset_input=0):
"""
Computes the mask, used to (de-)zerocenter domain and target
fields.
Parameters
----------
to_center_input : tuple, list, numpy.ndarray
A tuple of booleans which dimensions should be
zero-centered.
dimensions_input : tuple, list, numpy.ndarray
A tuple containing the masks desired shape.
offset_input : int, boolean
Specifies whether the zero-th dimension starts with an odd
or and even index, i.e. if it is shifted.
Returns
-------
result : np.ndarray
A 1/-1-alternating mask.
"""
## cast input
to_center = np.array(to_center_input)
dimensions = np.array(dimensions_input)
......@@ -118,10 +173,10 @@ if fft_machine == 'pyfftw':
offset[0] = int(offset_input)
## check for dimension match
if to_center.size != dimensions.size:
raise TypeError('The length of the supplied lists does not match')
raise TypeError('The length of the supplied lists does not match.')
## check that every dimension is larger than 1
if np.any(dimensions == 1):
return TypeError('Every dimensions must have an extent greater than 1')
return TypeError('Every dimensions must have an extent greater than 1.')
## build up the value memory
## compute an identifier for the parameter set
temp_id = tuple((tuple(to_center),tuple(dimensions),tuple(offset)))
......@@ -145,16 +200,39 @@ if fft_machine == 'pyfftw':
## The plan_dict stores the plan_and_info objects which correspond
## to a certain set of (field_val, domain, codomain) sets.
plan_dict = {}
def get_plan_and_info(self,domain,codomain):
def _get_plan_and_info(self,domain,codomain,**kwargs):
## generate a id-tuple which identifies the domain-codomain setting
temp_id = (domain.__identifier__(), codomain.__identifier__())
## generate the plan_and_info object if not already there
if not temp_id in self.plan_dict:
self.plan_dict[temp_id]=_fftw_plan_and_info(domain,codomain,self)
self.plan_dict[temp_id]=_fftw_plan_and_info(domain,codomain,self,**kwargs)
return self.plan_dict[temp_id]
def transform(self,field_val,domain,codomain):
current_plan_and_info=self.get_plan_and_info(domain,codomain)
def transform(self,field_val,domain,codomain,**kwargs):
"""
The pyfftw transform function.
Parameters
----------
field_val : numpy.ndarray
The value-array of the field which is supposed to
be transformed.
domain : nifty.rg.nifty_rg.rg_space
The domain of the space which should be transformed.
codomain : nifty.rg.nifty_rg.rg_space
The taget into which the field should be transformed.
**kwargs : *optional*
Further kwargs are passed to the create_mpi_plan routine.
Returns
-------
result : np.ndarray
Fourier-transformed pendant of the input field.
"""
current_plan_and_info=self._get_plan_and_info(domain,codomain,**kwargs)
## Prepare the input data
field_val*=current_plan_and_info.get_codomain_centering_mask()
## Define a abbreviation for the fftw plan
......@@ -169,7 +247,38 @@ if fft_machine == 'pyfftw':
elif fft_machine == 'gfft' or 'gfft_fallback':
class fft_gfft(fft):
def transform(self,field_val,domain,codomain):
"""
The gfft pendant of a fft object.
Parameters
----------
None
"""
def transform(self,field_val,domain,codomain,**kwargs):
"""
The gfft transform function.
Parameters
----------
field_val : numpy.ndarray
The value-array of the field which is supposed to
be transformed.
domain : nifty.rg.nifty_rg.rg_space
The domain of the space which should be transformed.
codomain : nifty.rg.nifty_rg.rg_space
The taget into which the field should be transformed.
**kwargs : *optional*
Further kwargs are not processed.
Returns
-------
result : np.ndarray
Fourier-transformed pendant of the input field.
"""
naxes = (np.size(domain.para)-1)//2
if(codomain.fourier):
ftmachine = "fft"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment