Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
ada34dff
Commit
ada34dff
authored
Feb 13, 2015
by
ultimanet
Browse files
Fixed pickle compatibility of the fft_factory
parent
34fad519
Changes
1
Hide whitespace changes
Inline
Side-by-side
rg/fft_rg.py
View file @
ada34dff
...
...
@@ -29,227 +29,155 @@ The fft objects must get 3 parameters:
The rg_space into which the field is transformed
'''
def
fft_factory
():
class
fft
(
object
):
def
transform
(
self
,
field_val
,
domain
,
codomain
,
**
kwargs
):
return
None
def
fft_factory
():
if
fft_machine
==
'pyfftw'
:
return
fft_fftw
()
elif
fft_machine
==
'gfft'
or
'gfft_fallback'
:
return
fft_gfft
()
class
fft_fftw
(
fft
):
## The instances of plan_and_info store the fftw plan and all
## other information needed in order to perform a mpi-fftw transformation
class
plan_and_info
(
object
):
def
__init__
(
self
,
domain
,
codomain
,
fft_fftw_context
):
self
.
compute_plan_and_info
(
domain
,
codomain
,
fft_fftw_context
)
def
set_plan
(
self
,
x
):
self
.
plan
=
x
def
get_plan
(
self
):
return
self
.
plan
def
set_domain_centering_mask
(
self
,
x
):
self
.
domain_centering_mask
=
x
def
get_domain_centering_mask
(
self
):
return
self
.
domain_centering_mask
def
set_codomain_centering_mask
(
self
,
x
):
self
.
codomain_centering_mask
=
x
def
get_codomain_centering_mask
(
self
):
return
self
.
codomain_centering_mask
def
compute_plan_and_info
(
self
,
domain
,
codomain
,
fft_fftw_context
):
self
.
input_dtype
=
'complex128'
self
.
output_dtype
=
'complex128'
self
.
global_input_shape
=
domain
.
dim
(
split
=
True
)
self
.
global_output_shape
=
codomain
.
dim
(
split
=
True
)
self
.
fftw_local_size
=
pyfftw
.
local_size
(
self
.
global_input_shape
)
self
.
in_zero_centered_dimensions
=
domain
.
zerocenter
()[::
-
1
]
self
.
out_zero_centered_dimensions
=
codomain
.
zerocenter
()[::
-
1
]
self
.
local_node_dimensions
=
np
.
append
((
self
.
fftw_local_size
[
1
],),
self
.
global_input_shape
[
1
:])
self
.
offsetQ
=
self
.
fftw_local_size
[
2
]
%
2
if
codomain
.
fourier
==
True
:
self
.
direction
=
'FFTW_FORWARD'
else
:
self
.
direction
=
'FFTW_BACKWARD'
##compute the centering masks
self
.
set_domain_centering_mask
(
fft_fftw_context
.
get_centering_mask
(
self
.
in_zero_centered_dimensions
,
self
.
local_node_dimensions
,
self
.
offsetQ
))
self
.
set_codomain_centering_mask
(
fft_fftw_context
.
get_centering_mask
(
self
.
out_zero_centered_dimensions
,
self
.
local_node_dimensions
,
self
.
offsetQ
))
self
.
set_plan
(
pyfftw
.
create_mpi_plan
(
input_shape
=
self
.
global_input_shape
,
input_dtype
=
self
.
input_dtype
,
output_dtype
=
self
.
output_dtype
,
direction
=
self
.
direction
)
##,flags=["FFTW_ESTIMATE"]))
)
## initialize the dictionary which stores the values from get_centering_mask
centering_mask_dict
=
{}
def
get_centering_mask
(
self
,
to_center_input
,
dimensions_input
,
offset_input
=
0
):
## cast input
to_center
=
np
.
array
(
to_center_input
)
dimensions
=
np
.
array
(
dimensions_input
)
## cast the offset_input into the shape of to_center
offset
=
np
.
zeros
(
to_center
.
shape
,
dtype
=
int
)
offset
[
0
]
=
int
(
offset_input
)
## check for dimension match
if
to_center
.
size
!=
dimensions
.
size
:
raise
TypeError
(
'The length of the supplied lists does not match'
)
## check that every dimension is larger than 1
if
np
.
any
(
dimensions
==
1
):
return
TypeError
(
'Every dimensions must have an extent greater than 1'
)
## build up the value memory
## compute an identifier for the parameter set
temp_id
=
tuple
((
tuple
(
to_center
),
tuple
(
dimensions
),
tuple
(
offset
)))
if
not
temp_id
in
self
.
centering_mask_dict
:
## use np.tile in order to stack the core alternation scheme
## until the desired format is constructed.
core
=
np
.
fromfunction
(
lambda
*
args
:
(
-
1
)
**
(
np
.
tensordot
(
to_center
,
args
+
offset
.
reshape
(
offset
.
shape
+
(
1
,)
*
(
np
.
array
(
args
).
ndim
-
1
)),
1
)),
(
2
,)
*
to_center
.
size
)
centering_mask
=
np
.
tile
(
core
,
dimensions
//
2
)
## for the dimensions of odd size corresponding slices must be added
for
i
in
range
(
centering_mask
.
ndim
):
## check if the size of the certain dimension is odd or even
if
(
dimensions
%
2
)[
i
]
==
0
:
continue
## prepare the slice object
temp_slice
=
(
slice
(
None
),)
*
i
+
(
slice
(
-
2
,
-
1
,
1
),)
+
(
slice
(
None
),)
*
(
centering_mask
.
ndim
-
1
-
i
)
## append the slice to the centering_mask
centering_mask
=
np
.
append
(
centering_mask
,
centering_mask
[
temp_slice
],
axis
=
i
)
self
.
centering_mask_dict
[
temp_id
]
=
centering_mask
return
self
.
centering_mask_dict
[
temp_id
]
## The plan_dict stores the plan_and_info objects which correspond
## to a certain set of (field_val, domain, codomain) sets.
plan_dict
=
{}
def
get_plan_and_info
(
self
,
domain
,
codomain
):
## generate a id-tuple which identifies the domain-codomain setting
temp_id
=
(
domain
.
__identifier__
(),
codomain
.
__identifier__
())
## generate the plan_and_info object if not already there
if
not
temp_id
in
self
.
plan_dict
:
self
.
plan_dict
[
temp_id
]
=
self
.
plan_and_info
(
domain
,
codomain
,
self
)
return
self
.
plan_dict
[
temp_id
]
def
transform
(
self
,
field_val
,
domain
,
codomain
):
current_plan_and_info
=
self
.
get_plan_and_info
(
domain
,
codomain
)
print
current_plan_and_info
.
get_domain_centering_mask
()
print
current_plan_and_info
.
get_codomain_centering_mask
()
## Prepare the input data
field_val
*=
current_plan_and_info
.
get_codomain_centering_mask
()
## Define a abbreviation for the fftw plan
p
=
current_plan_and_info
.
get_plan
()
## load the field into the plan
if
p
.
has_input
:
p
.
input_array
[:]
=
field_val
## execute the plan
p
()
return
p
.
output_array
*
current_plan_and_info
.
get_domain_centering_mask
()
return
fft_fftw
()
class
fft
(
object
):
def
transform
(
self
,
field_val
,
domain
,
codomain
,
**
kwargs
):
return
None
if
fft_machine
==
'pyfftw'
:
## The instances of plan_and_info store the fftw plan and all
## other information needed in order to perform a mpi-fftw transformation
class
_fftw_plan_and_info
(
object
):
def
__init__
(
self
,
domain
,
codomain
,
fft_fftw_context
):
self
.
compute_plan_and_info
(
domain
,
codomain
,
fft_fftw_context
)
def
set_plan
(
self
,
x
):
self
.
plan
=
x
def
get_plan
(
self
):
return
self
.
plan
elif
fft_machine
==
'gfft'
or
'gfft_fallback'
:
class
fft_gfft
(
fft
):
def
transform
(
self
,
field_val
,
domain
,
codomain
):
naxes
=
(
np
.
size
(
domain
.
para
)
-
1
)
//
2
if
(
codomain
.
fourier
):
ftmachine
=
"fft"
else
:
ftmachine
=
"ifft"
## transform and return
if
(
domain
.
datatype
==
np
.
float64
):
return
gfft
.
gfft
(
field_val
.
astype
(
np
.
complex128
),
in_ax
=
[],
out_ax
=
[],
ftmachine
=
ftmachine
,
in_zero_center
=
domain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
out_zero_center
=
codomain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
enforce_hermitian_symmetry
=
bool
(
codomain
.
para
[
naxes
]
==
1
),
W
=-
1
,
alpha
=-
1
,
verbose
=
False
)
else
:
return
gfft
.
gfft
(
field_val
,
in_ax
=
[],
out_ax
=
[],
ftmachine
=
ftmachine
,
in_zero_center
=
domain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
out_zero_center
=
codomain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
enforce_hermitian_symmetry
=
bool
(
codomain
.
para
[
naxes
]
==
1
),
W
=-
1
,
alpha
=-
1
,
verbose
=
False
)
def
set_domain_centering_mask
(
self
,
x
):
self
.
domain_centering_mask
=
x
def
get_domain_centering_mask
(
self
):
return
self
.
domain_centering_mask
def
set_codomain_centering_mask
(
self
,
x
):
self
.
codomain_centering_mask
=
x
def
get_codomain_centering_mask
(
self
):
return
self
.
codomain_centering_mask
return
fft_gfft
()
def
compute_plan_and_info
(
self
,
domain
,
codomain
,
fft_fftw_context
):
self
.
input_dtype
=
'complex128'
self
.
output_dtype
=
'complex128'
self
.
global_input_shape
=
domain
.
dim
(
split
=
True
)
self
.
global_output_shape
=
codomain
.
dim
(
split
=
True
)
self
.
fftw_local_size
=
pyfftw
.
local_size
(
self
.
global_input_shape
)
self
.
in_zero_centered_dimensions
=
domain
.
zerocenter
()[::
-
1
]
self
.
out_zero_centered_dimensions
=
codomain
.
zerocenter
()[::
-
1
]
'''
def fftw(self,y):
## setting up the environment variables
## input_shape
input_shape=self.para[:self.naxes()]
## datatypes
if(True):#self.datatype==(np.float64 or np.complex128)):
input_dtype='complex128'
output_dtype='complex128'
else:
print self.datatype
raise ValueError(about._errors.cstring("ERROR: space has unsupported datatype != float64."))
## zero_centers
in_zero_center = np.abs(self.para[-naxes:])
out_zero_center = np.abs(codomain.para[-naxes:])
## calculate the inversion masks for the (non-)zero-centered cases
self
.
local_node_dimensions
=
np
.
append
((
self
.
fftw_local_size
[
1
],),
self
.
global_input_shape
[
1
:])
self
.
offsetQ
=
self
.
fftw_local_size
[
2
]
%
2
if
codomain
.
fourier
==
True
:
self
.
direction
=
'FFTW_FORWARD'
else
:
self
.
direction
=
'FFTW_BACKWARD'
##compute the centering masks
self
.
set_domain_centering_mask
(
fft_fftw_context
.
get_centering_mask
(
self
.
in_zero_centered_dimensions
,
self
.
local_node_dimensions
,
self
.
offsetQ
))
self
.
set_codomain_centering_mask
(
fft_fftw_context
.
get_centering_mask
(
self
.
out_zero_centered_dimensions
,
self
.
local_node_dimensions
,
self
.
offsetQ
))
self
.
set_plan
(
pyfftw
.
create_mpi_plan
(
input_shape
=
self
.
global_input_shape
,
input_dtype
=
self
.
input_dtype
,
output_dtype
=
self
.
output_dtype
,
direction
=
self
.
direction
,
flags
=
[
"FFTW_ESTIMATE"
])
)
## TODO: Compute the inverties only for the specific slice!
## TODO: make the inverties a variable of the classes instance!
pre_inverty = getInvertionMask(out_zero_center,input_shape)
##TODO: Does the transformed rg_space ALWAYS have the same shape?
post_inverty = getInvertionMask(in_zero_center,input_shape)
## Prepare the input array
y*=pre_inverty
## Setting up the MPI plan
p = create_mpi_plan(input_shape=input_shape, input_dtype=input_dtype, output_dtype=output_dtype,direction=direction)#,flags=["FFTW_ESTIMATE"])
## load the field into the plan
if p.has_input:
p.input_array[:] = y[p.input_slice.start:p.input_slice.stop]
## execute the plan
p()
localTx=p.output_array*post_inverty[p.output_slice.start:p.output_slice.stop]
tempTx=np.empty(x.shape,dtype=output_dtype)
MPI.COMM_WORLD.Allgather([localTx, MPI.COMPLEX],[tempTx, MPI.COMPLEX])
return tempTx
'''
'''
def update_zerocenter_mask(self):
onOffList=np.array(self.para[-naxes:],dtype=bool)
dim=np.array(self.para[:self.naxes()],dtype=int)
os=0
# check if the length of onOffList equals the number of supplied coordinates
if list.size != dim.size:
class
fft_fftw
(
fft
):
## initialize the dictionary which stores the values from get_centering_mask
centering_mask_dict
=
{}
def
get_centering_mask
(
self
,
to_center_input
,
dimensions_input
,
offset_input
=
0
):
## cast input
to_center
=
np
.
array
(
to_center_input
)
dimensions
=
np
.
array
(
dimensions_input
)
## cast the offset_input into the shape of to_center
offset
=
np
.
zeros
(
to_center
.
shape
,
dtype
=
int
)
offset
[
0
]
=
int
(
offset_input
)
## check for dimension match
if
to_center
.
size
!=
dimensions
.
size
:
raise
TypeError
(
'The length of the supplied lists does not match'
)
## check that every dimension is larger than 1
if
np
.
any
(
dimensions
==
1
):
return
TypeError
(
'Every dimensions must have an extent greater than 1'
)
## build up the value memory
## compute an identifier for the parameter set
temp_id
=
tuple
((
tuple
(
to_center
),
tuple
(
dimensions
),
tuple
(
offset
)))
if
not
temp_id
in
self
.
centering_mask_dict
:
## use np.tile in order to stack the core alternation scheme
## until the desired format is constructed.
core
=
np
.
fromfunction
(
lambda
*
args
:
(
-
1
)
**
(
np
.
tensordot
(
to_center
,
args
+
offset
.
reshape
(
offset
.
shape
+
(
1
,)
*
(
np
.
array
(
args
).
ndim
-
1
)),
1
)),
(
2
,)
*
to_center
.
size
)
centering_mask
=
np
.
tile
(
core
,
dimensions
//
2
)
## for the dimensions of odd size corresponding slices must be added
for
i
in
range
(
centering_mask
.
ndim
):
## check if the size of the certain dimension is odd or even
if
(
dimensions
%
2
)[
i
]
==
0
:
continue
## prepare the slice object
temp_slice
=
(
slice
(
None
),)
*
i
+
(
slice
(
-
2
,
-
1
,
1
),)
+
(
slice
(
None
),)
*
(
centering_mask
.
ndim
-
1
-
i
)
## append the slice to the centering_mask
centering_mask
=
np
.
append
(
centering_mask
,
centering_mask
[
temp_slice
],
axis
=
i
)
self
.
centering_mask_dict
[
temp_id
]
=
centering_mask
return
self
.
centering_mask_dict
[
temp_id
]
inverty=np.fromfunction(lambda *args : (-1)**(np.tensordot(list,args+os.reshape(os.shape+(1,)*(np.array(args).ndim-1)),1)) , dim)
return inverty
## The plan_dict stores the plan_and_info objects which correspond
## to a certain set of (field_val, domain, codomain) sets.
plan_dict
=
{}
def
get_plan_and_info
(
self
,
domain
,
codomain
):
## generate a id-tuple which identifies the domain-codomain setting
temp_id
=
(
domain
.
__identifier__
(),
codomain
.
__identifier__
())
## generate the plan_and_info object if not already there
if
not
temp_id
in
self
.
plan_dict
:
self
.
plan_dict
[
temp_id
]
=
_fftw_plan_and_info
(
domain
,
codomain
,
self
)
return
self
.
plan_dict
[
temp_id
]
def
transform
(
self
,
field_val
,
domain
,
codomain
):
current_plan_and_info
=
self
.
get_plan_and_info
(
domain
,
codomain
)
## Prepare the input data
field_val
*=
current_plan_and_info
.
get_codomain_centering_mask
()
## Define a abbreviation for the fftw plan
p
=
current_plan_and_info
.
get_plan
()
## load the field into the plan
if
p
.
has_input
:
p
.
input_array
[:]
=
field_val
## execute the plan
p
()
return
p
.
output_array
*
current_plan_and_info
.
get_domain_centering_mask
()
'''
'''
## TODO: Compute the inverties only for the specific slice!
## TODO: make the inverties a variable of the classes instance!
pre_inverty = getInvertionMask(out_zero_center,input_shape)
##TODO: Does the transformed rg_space ALWAYS have the same shape?
post_inverty = getInvertionMask(in_zero_center,input_shape)
## Prepare the input array
y*=pre_inverty
p = create_mpi_plan(input_shape=input_shape, input_dtype=input_dtype, output_dtype=output_dtype,direction=direction)#,flags=["FFTW_ESTIMATE"])
'''
#inverty = np.fromfunction(lambda *args : (-1)**(np.tensordot(list,args+os.reshape(os.shape+(1,)*(np.array(args).ndim-1)),1)) , dim)
\ No newline at end of file
elif
fft_machine
==
'gfft'
or
'gfft_fallback'
:
class
fft_gfft
(
fft
):
def
transform
(
self
,
field_val
,
domain
,
codomain
):
naxes
=
(
np
.
size
(
domain
.
para
)
-
1
)
//
2
if
(
codomain
.
fourier
):
ftmachine
=
"fft"
else
:
ftmachine
=
"ifft"
## transform and return
if
(
domain
.
datatype
==
np
.
float64
):
return
gfft
.
gfft
(
field_val
.
astype
(
np
.
complex128
),
in_ax
=
[],
out_ax
=
[],
ftmachine
=
ftmachine
,
in_zero_center
=
domain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
out_zero_center
=
codomain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
enforce_hermitian_symmetry
=
bool
(
codomain
.
para
[
naxes
]
==
1
),
W
=-
1
,
alpha
=-
1
,
verbose
=
False
)
else
:
return
gfft
.
gfft
(
field_val
,
in_ax
=
[],
out_ax
=
[],
ftmachine
=
ftmachine
,
in_zero_center
=
domain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
out_zero_center
=
codomain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
enforce_hermitian_symmetry
=
bool
(
codomain
.
para
[
naxes
]
==
1
),
W
=-
1
,
alpha
=-
1
,
verbose
=
False
)
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment