Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
N
NIFTy
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Monitor
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
ift
NIFTy
Commits
ada34dff
Commit
ada34dff
authored
10 years ago
by
ultimanet
Browse files
Options
Downloads
Patches
Plain Diff
Fixed pickle compatibility of the fft_factory
parent
34fad519
No related branches found
No related tags found
1 merge request
!3
Nifty 2
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
rg/fft_rg.py
+138
-210
138 additions, 210 deletions
rg/fft_rg.py
with
138 additions
and
210 deletions
rg/fft_rg.py
+
138
−
210
View file @
ada34dff
...
@@ -30,17 +30,22 @@ The fft objects must get 3 parameters:
...
@@ -30,17 +30,22 @@ The fft objects must get 3 parameters:
'''
'''
def
fft_factory
():
def
fft_factory
():
if
fft_machine
==
'
pyfftw
'
:
return
fft_fftw
()
elif
fft_machine
==
'
gfft
'
or
'
gfft_fallback
'
:
return
fft_gfft
()
class
fft
(
object
):
class
fft
(
object
):
def
transform
(
self
,
field_val
,
domain
,
codomain
,
**
kwargs
):
def
transform
(
self
,
field_val
,
domain
,
codomain
,
**
kwargs
):
return
None
return
None
if
fft_machine
==
'
pyfftw
'
:
if
fft_machine
==
'
pyfftw
'
:
class
fft_fftw
(
fft
):
## The instances of plan_and_info store the fftw plan and all
## The instances of plan_and_info store the fftw plan and all
## other information needed in order to perform a mpi-fftw transformation
## other information needed in order to perform a mpi-fftw transformation
class
plan_and_info
(
object
):
class
_fftw_
plan_and_info
(
object
):
def
__init__
(
self
,
domain
,
codomain
,
fft_fftw_context
):
def
__init__
(
self
,
domain
,
codomain
,
fft_fftw_context
):
self
.
compute_plan_and_info
(
domain
,
codomain
,
fft_fftw_context
)
self
.
compute_plan_and_info
(
domain
,
codomain
,
fft_fftw_context
)
...
@@ -97,9 +102,11 @@ def fft_factory():
...
@@ -97,9 +102,11 @@ def fft_factory():
input_shape
=
self
.
global_input_shape
,
input_shape
=
self
.
global_input_shape
,
input_dtype
=
self
.
input_dtype
,
input_dtype
=
self
.
input_dtype
,
output_dtype
=
self
.
output_dtype
,
output_dtype
=
self
.
output_dtype
,
direction
=
self
.
direction
)
direction
=
self
.
direction
,
##,
flags=["FFTW_ESTIMATE"])
)
flags
=
[
"
FFTW_ESTIMATE
"
])
)
)
class
fft_fftw
(
fft
):
## initialize the dictionary which stores the values from get_centering_mask
## initialize the dictionary which stores the values from get_centering_mask
centering_mask_dict
=
{}
centering_mask_dict
=
{}
def
get_centering_mask
(
self
,
to_center_input
,
dimensions_input
,
offset_input
=
0
):
def
get_centering_mask
(
self
,
to_center_input
,
dimensions_input
,
offset_input
=
0
):
...
@@ -143,13 +150,11 @@ def fft_factory():
...
@@ -143,13 +150,11 @@ def fft_factory():
temp_id
=
(
domain
.
__identifier__
(),
codomain
.
__identifier__
())
temp_id
=
(
domain
.
__identifier__
(),
codomain
.
__identifier__
())
## generate the plan_and_info object if not already there
## generate the plan_and_info object if not already there
if
not
temp_id
in
self
.
plan_dict
:
if
not
temp_id
in
self
.
plan_dict
:
self
.
plan_dict
[
temp_id
]
=
self
.
plan_and_info
(
domain
,
codomain
,
self
)
self
.
plan_dict
[
temp_id
]
=
_fftw_
plan_and_info
(
domain
,
codomain
,
self
)
return
self
.
plan_dict
[
temp_id
]
return
self
.
plan_dict
[
temp_id
]
def
transform
(
self
,
field_val
,
domain
,
codomain
):
def
transform
(
self
,
field_val
,
domain
,
codomain
):
current_plan_and_info
=
self
.
get_plan_and_info
(
domain
,
codomain
)
current_plan_and_info
=
self
.
get_plan_and_info
(
domain
,
codomain
)
print
current_plan_and_info
.
get_domain_centering_mask
()
print
current_plan_and_info
.
get_codomain_centering_mask
()
## Prepare the input data
## Prepare the input data
field_val
*=
current_plan_and_info
.
get_codomain_centering_mask
()
field_val
*=
current_plan_and_info
.
get_codomain_centering_mask
()
## Define a abbreviation for the fftw plan
## Define a abbreviation for the fftw plan
...
@@ -161,10 +166,6 @@ def fft_factory():
...
@@ -161,10 +166,6 @@ def fft_factory():
p
()
p
()
return
p
.
output_array
*
current_plan_and_info
.
get_domain_centering_mask
()
return
p
.
output_array
*
current_plan_and_info
.
get_domain_centering_mask
()
return
fft_fftw
()
elif
fft_machine
==
'
gfft
'
or
'
gfft_fallback
'
:
elif
fft_machine
==
'
gfft
'
or
'
gfft_fallback
'
:
class
fft_gfft
(
fft
):
class
fft_gfft
(
fft
):
...
@@ -180,76 +181,3 @@ def fft_factory():
...
@@ -180,76 +181,3 @@ def fft_factory():
else
:
else
:
return
gfft
.
gfft
(
field_val
,
in_ax
=
[],
out_ax
=
[],
ftmachine
=
ftmachine
,
in_zero_center
=
domain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
out_zero_center
=
codomain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
enforce_hermitian_symmetry
=
bool
(
codomain
.
para
[
naxes
]
==
1
),
W
=-
1
,
alpha
=-
1
,
verbose
=
False
)
return
gfft
.
gfft
(
field_val
,
in_ax
=
[],
out_ax
=
[],
ftmachine
=
ftmachine
,
in_zero_center
=
domain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
out_zero_center
=
codomain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
enforce_hermitian_symmetry
=
bool
(
codomain
.
para
[
naxes
]
==
1
),
W
=-
1
,
alpha
=-
1
,
verbose
=
False
)
\ No newline at end of file
return
fft_gfft
()
'''
def fftw(self,y):
## setting up the environment variables
## input_shape
input_shape=self.para[:self.naxes()]
## datatypes
if(True):#self.datatype==(np.float64 or np.complex128)):
input_dtype=
'
complex128
'
output_dtype=
'
complex128
'
else:
print self.datatype
raise ValueError(about._errors.cstring(
"
ERROR: space has unsupported datatype != float64.
"
))
## zero_centers
in_zero_center = np.abs(self.para[-naxes:])
out_zero_center = np.abs(codomain.para[-naxes:])
## calculate the inversion masks for the (non-)zero-centered cases
## TODO: Compute the inverties only for the specific slice!
## TODO: make the inverties a variable of the classes instance!
pre_inverty = getInvertionMask(out_zero_center,input_shape)
##TODO: Does the transformed rg_space ALWAYS have the same shape?
post_inverty = getInvertionMask(in_zero_center,input_shape)
## Prepare the input array
y*=pre_inverty
## Setting up the MPI plan
p = create_mpi_plan(input_shape=input_shape, input_dtype=input_dtype, output_dtype=output_dtype,direction=direction)#,flags=[
"
FFTW_ESTIMATE
"
])
## load the field into the plan
if p.has_input:
p.input_array[:] = y[p.input_slice.start:p.input_slice.stop]
## execute the plan
p()
localTx=p.output_array*post_inverty[p.output_slice.start:p.output_slice.stop]
tempTx=np.empty(x.shape,dtype=output_dtype)
MPI.COMM_WORLD.Allgather([localTx, MPI.COMPLEX],[tempTx, MPI.COMPLEX])
return tempTx
'''
'''
def update_zerocenter_mask(self):
onOffList=np.array(self.para[-naxes:],dtype=bool)
dim=np.array(self.para[:self.naxes()],dtype=int)
os=0
# check if the length of onOffList equals the number of supplied coordinates
if list.size != dim.size:
raise TypeError(
'
The length of the supplied lists does not match
'
)
inverty=np.fromfunction(lambda *args : (-1)**(np.tensordot(list,args+os.reshape(os.shape+(1,)*(np.array(args).ndim-1)),1)) , dim)
return inverty
'''
'''
## TODO: Compute the inverties only for the specific slice!
## TODO: make the inverties a variable of the classes instance!
pre_inverty = getInvertionMask(out_zero_center,input_shape)
##TODO: Does the transformed rg_space ALWAYS have the same shape?
post_inverty = getInvertionMask(in_zero_center,input_shape)
## Prepare the input array
y*=pre_inverty
p = create_mpi_plan(input_shape=input_shape, input_dtype=input_dtype, output_dtype=output_dtype,direction=direction)#,flags=[
"
FFTW_ESTIMATE
"
])
'''
#inverty = np.fromfunction(lambda *args : (-1)**(np.tensordot(list,args+os.reshape(os.shape+(1,)*(np.array(args).ndim-1)),1)) , dim)
\ No newline at end of file
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment