Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
ada34dff
Commit
ada34dff
authored
Feb 13, 2015
by
ultimanet
Browse files
Fixed pickle compatibility of the fft_factory
parent
34fad519
Changes
1
Show whitespace changes
Inline
Side-by-side
rg/fft_rg.py
View file @
ada34dff
...
...
@@ -30,17 +30,22 @@ The fft objects must get 3 parameters:
'''
def
fft_factory
():
if
fft_machine
==
'pyfftw'
:
return
fft_fftw
()
elif
fft_machine
==
'gfft'
or
'gfft_fallback'
:
return
fft_gfft
()
class
fft
(
object
):
class
fft
(
object
):
def
transform
(
self
,
field_val
,
domain
,
codomain
,
**
kwargs
):
return
None
if
fft_machine
==
'pyfftw'
:
class
fft_fftw
(
fft
):
if
fft_machine
==
'pyfftw'
:
## The instances of plan_and_info store the fftw plan and all
## other information needed in order to perform a mpi-fftw transformation
class
plan_and_info
(
object
):
class
_fftw_
plan_and_info
(
object
):
def
__init__
(
self
,
domain
,
codomain
,
fft_fftw_context
):
self
.
compute_plan_and_info
(
domain
,
codomain
,
fft_fftw_context
)
...
...
@@ -97,9 +102,11 @@ def fft_factory():
input_shape
=
self
.
global_input_shape
,
input_dtype
=
self
.
input_dtype
,
output_dtype
=
self
.
output_dtype
,
direction
=
self
.
direction
)
##,
flags=["FFTW_ESTIMATE"])
)
direction
=
self
.
direction
,
flags
=
[
"FFTW_ESTIMATE"
])
)
class
fft_fftw
(
fft
):
## initialize the dictionary which stores the values from get_centering_mask
centering_mask_dict
=
{}
def
get_centering_mask
(
self
,
to_center_input
,
dimensions_input
,
offset_input
=
0
):
...
...
@@ -143,13 +150,11 @@ def fft_factory():
temp_id
=
(
domain
.
__identifier__
(),
codomain
.
__identifier__
())
## generate the plan_and_info object if not already there
if
not
temp_id
in
self
.
plan_dict
:
self
.
plan_dict
[
temp_id
]
=
self
.
plan_and_info
(
domain
,
codomain
,
self
)
self
.
plan_dict
[
temp_id
]
=
_fftw_
plan_and_info
(
domain
,
codomain
,
self
)
return
self
.
plan_dict
[
temp_id
]
def
transform
(
self
,
field_val
,
domain
,
codomain
):
current_plan_and_info
=
self
.
get_plan_and_info
(
domain
,
codomain
)
print
current_plan_and_info
.
get_domain_centering_mask
()
print
current_plan_and_info
.
get_codomain_centering_mask
()
## Prepare the input data
field_val
*=
current_plan_and_info
.
get_codomain_centering_mask
()
## Define a abbreviation for the fftw plan
...
...
@@ -161,12 +166,8 @@ def fft_factory():
p
()
return
p
.
output_array
*
current_plan_and_info
.
get_domain_centering_mask
()
return
fft_fftw
()
elif
fft_machine
==
'gfft'
or
'gfft_fallback'
:
elif
fft_machine
==
'gfft'
or
'gfft_fallback'
:
class
fft_gfft
(
fft
):
def
transform
(
self
,
field_val
,
domain
,
codomain
):
naxes
=
(
np
.
size
(
domain
.
para
)
-
1
)
//
2
...
...
@@ -180,76 +181,3 @@ def fft_factory():
else
:
return
gfft
.
gfft
(
field_val
,
in_ax
=
[],
out_ax
=
[],
ftmachine
=
ftmachine
,
in_zero_center
=
domain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
out_zero_center
=
codomain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
enforce_hermitian_symmetry
=
bool
(
codomain
.
para
[
naxes
]
==
1
),
W
=-
1
,
alpha
=-
1
,
verbose
=
False
)
\ No newline at end of file
return
fft_gfft
()
'''
def fftw(self,y):
## setting up the environment variables
## input_shape
input_shape=self.para[:self.naxes()]
## datatypes
if(True):#self.datatype==(np.float64 or np.complex128)):
input_dtype='complex128'
output_dtype='complex128'
else:
print self.datatype
raise ValueError(about._errors.cstring("ERROR: space has unsupported datatype != float64."))
## zero_centers
in_zero_center = np.abs(self.para[-naxes:])
out_zero_center = np.abs(codomain.para[-naxes:])
## calculate the inversion masks for the (non-)zero-centered cases
## TODO: Compute the inverties only for the specific slice!
## TODO: make the inverties a variable of the classes instance!
pre_inverty = getInvertionMask(out_zero_center,input_shape)
##TODO: Does the transformed rg_space ALWAYS have the same shape?
post_inverty = getInvertionMask(in_zero_center,input_shape)
## Prepare the input array
y*=pre_inverty
## Setting up the MPI plan
p = create_mpi_plan(input_shape=input_shape, input_dtype=input_dtype, output_dtype=output_dtype,direction=direction)#,flags=["FFTW_ESTIMATE"])
## load the field into the plan
if p.has_input:
p.input_array[:] = y[p.input_slice.start:p.input_slice.stop]
## execute the plan
p()
localTx=p.output_array*post_inverty[p.output_slice.start:p.output_slice.stop]
tempTx=np.empty(x.shape,dtype=output_dtype)
MPI.COMM_WORLD.Allgather([localTx, MPI.COMPLEX],[tempTx, MPI.COMPLEX])
return tempTx
'''
'''
def update_zerocenter_mask(self):
onOffList=np.array(self.para[-naxes:],dtype=bool)
dim=np.array(self.para[:self.naxes()],dtype=int)
os=0
# check if the length of onOffList equals the number of supplied coordinates
if list.size != dim.size:
raise TypeError('The length of the supplied lists does not match')
inverty=np.fromfunction(lambda *args : (-1)**(np.tensordot(list,args+os.reshape(os.shape+(1,)*(np.array(args).ndim-1)),1)) , dim)
return inverty
'''
'''
## TODO: Compute the inverties only for the specific slice!
## TODO: make the inverties a variable of the classes instance!
pre_inverty = getInvertionMask(out_zero_center,input_shape)
##TODO: Does the transformed rg_space ALWAYS have the same shape?
post_inverty = getInvertionMask(in_zero_center,input_shape)
## Prepare the input array
y*=pre_inverty
p = create_mpi_plan(input_shape=input_shape, input_dtype=input_dtype, output_dtype=output_dtype,direction=direction)#,flags=["FFTW_ESTIMATE"])
'''
#inverty = np.fromfunction(lambda *args : (-1)**(np.tensordot(list,args+os.reshape(os.shape+(1,)*(np.array(args).ndim-1)),1)) , dim)
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment