Commit ada34dff authored by ultimanet's avatar ultimanet
Browse files

Fixed pickle compatibility of the fft_factory

parent 34fad519
......@@ -30,17 +30,22 @@ The fft objects must get 3 parameters:
'''
def fft_factory():
if fft_machine == 'pyfftw':
return fft_fftw()
elif fft_machine == 'gfft' or 'gfft_fallback':
return fft_gfft()
class fft(object):
class fft(object):
def transform(self,field_val,domain,codomain,**kwargs):
return None
if fft_machine == 'pyfftw':
class fft_fftw(fft):
if fft_machine == 'pyfftw':
## The instances of plan_and_info store the fftw plan and all
## other information needed in order to perform a mpi-fftw transformation
class plan_and_info(object):
class _fftw_plan_and_info(object):
def __init__(self,domain,codomain,fft_fftw_context):
self.compute_plan_and_info(domain,codomain,fft_fftw_context)
......@@ -97,9 +102,11 @@ def fft_factory():
input_shape=self.global_input_shape,
input_dtype=self.input_dtype,
output_dtype=self.output_dtype,
direction=self.direction)
##,flags=["FFTW_ESTIMATE"]))
direction=self.direction,
flags=["FFTW_ESTIMATE"])
)
class fft_fftw(fft):
## initialize the dictionary which stores the values from get_centering_mask
centering_mask_dict = {}
def get_centering_mask(self, to_center_input, dimensions_input, offset_input=0):
......@@ -143,13 +150,11 @@ def fft_factory():
temp_id = (domain.__identifier__(), codomain.__identifier__())
## generate the plan_and_info object if not already there
if not temp_id in self.plan_dict:
self.plan_dict[temp_id]=self.plan_and_info(domain,codomain,self)
self.plan_dict[temp_id]=_fftw_plan_and_info(domain,codomain,self)
return self.plan_dict[temp_id]
def transform(self,field_val,domain,codomain):
current_plan_and_info=self.get_plan_and_info(domain,codomain)
print current_plan_and_info.get_domain_centering_mask()
print current_plan_and_info.get_codomain_centering_mask()
## Prepare the input data
field_val*=current_plan_and_info.get_codomain_centering_mask()
## Define a abbreviation for the fftw plan
......@@ -161,12 +166,8 @@ def fft_factory():
p()
return p.output_array*current_plan_and_info.get_domain_centering_mask()
return fft_fftw()
elif fft_machine == 'gfft' or 'gfft_fallback':
elif fft_machine == 'gfft' or 'gfft_fallback':
class fft_gfft(fft):
def transform(self,field_val,domain,codomain):
naxes = (np.size(domain.para)-1)//2
......@@ -180,76 +181,3 @@ def fft_factory():
else:
return gfft.gfft(field_val,in_ax=[],out_ax=[],ftmachine=ftmachine,in_zero_center=domain.para[-naxes:].astype(np.bool).tolist(),out_zero_center=codomain.para[-naxes:].astype(np.bool).tolist(),enforce_hermitian_symmetry=bool(codomain.para[naxes]==1),W=-1,alpha=-1,verbose=False)
\ No newline at end of file
return fft_gfft()
'''
def fftw(self,y):
## setting up the environment variables
## input_shape
input_shape=self.para[:self.naxes()]
## datatypes
if(True):#self.datatype==(np.float64 or np.complex128)):
input_dtype='complex128'
output_dtype='complex128'
else:
print self.datatype
raise ValueError(about._errors.cstring("ERROR: space has unsupported datatype != float64."))
## zero_centers
in_zero_center = np.abs(self.para[-naxes:])
out_zero_center = np.abs(codomain.para[-naxes:])
## calculate the inversion masks for the (non-)zero-centered cases
## TODO: Compute the inverties only for the specific slice!
## TODO: make the inverties a variable of the classes instance!
pre_inverty = getInvertionMask(out_zero_center,input_shape)
##TODO: Does the transformed rg_space ALWAYS have the same shape?
post_inverty = getInvertionMask(in_zero_center,input_shape)
## Prepare the input array
y*=pre_inverty
## Setting up the MPI plan
p = create_mpi_plan(input_shape=input_shape, input_dtype=input_dtype, output_dtype=output_dtype,direction=direction)#,flags=["FFTW_ESTIMATE"])
## load the field into the plan
if p.has_input:
p.input_array[:] = y[p.input_slice.start:p.input_slice.stop]
## execute the plan
p()
localTx=p.output_array*post_inverty[p.output_slice.start:p.output_slice.stop]
tempTx=np.empty(x.shape,dtype=output_dtype)
MPI.COMM_WORLD.Allgather([localTx, MPI.COMPLEX],[tempTx, MPI.COMPLEX])
return tempTx
'''
'''
def update_zerocenter_mask(self):
onOffList=np.array(self.para[-naxes:],dtype=bool)
dim=np.array(self.para[:self.naxes()],dtype=int)
os=0
# check if the length of onOffList equals the number of supplied coordinates
if list.size != dim.size:
raise TypeError('The length of the supplied lists does not match')
inverty=np.fromfunction(lambda *args : (-1)**(np.tensordot(list,args+os.reshape(os.shape+(1,)*(np.array(args).ndim-1)),1)) , dim)
return inverty
'''
'''
## TODO: Compute the inverties only for the specific slice!
## TODO: make the inverties a variable of the classes instance!
pre_inverty = getInvertionMask(out_zero_center,input_shape)
##TODO: Does the transformed rg_space ALWAYS have the same shape?
post_inverty = getInvertionMask(in_zero_center,input_shape)
## Prepare the input array
y*=pre_inverty
p = create_mpi_plan(input_shape=input_shape, input_dtype=input_dtype, output_dtype=output_dtype,direction=direction)#,flags=["FFTW_ESTIMATE"])
'''
#inverty = np.fromfunction(lambda *args : (-1)**(np.tensordot(list,args+os.reshape(os.shape+(1,)*(np.array(args).ndim-1)),1)) , dim)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment