Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
c9ae0203
Commit
c9ae0203
authored
Jun 13, 2016
by
Jait Dixit
Browse files
Move individual cases in FFTW's transform to separate methods
parent
afa17f5a
Pipeline
#5049
skipped
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
rg/nifty_fft.py
View file @
c9ae0203
...
...
@@ -299,6 +299,109 @@ class FFTW(FFT):
return
result
def
_not_slicing_transform
(
self
,
val
,
domain
,
codomain
,
axes
,
**
kwargs
):
temp_val
=
val
.
copy_empty
(
distribution_strategy
=
'fftw'
)
about
.
warnings
.
cprint
(
'WARNING: Repacking d2o to fftw
\
distribution strategy'
)
temp_val
.
set_full_data
(
val
,
copy
=
False
)
# Recursive call to take advantage of the fact that the data
# necessary is already present on the nodes.
result
=
self
.
transform
(
temp_val
,
domain
,
codomain
,
axes
,
**
kwargs
)
return_val
=
val
.
copy_empty
(
distribution_strategy
=
val
.
distribution_strategy
)
return_val
.
set_full_data
(
result
,
copy
=
False
)
return
return_val
def
_slicing_local_transform
(
self
,
val
,
domain
,
codomain
,
axes
,
**
kwargs
):
current_info
=
self
.
_get_transform_info
(
domain
,
codomain
,
is_local
=
True
,
**
kwargs
)
# Compute transform for the local data
result
=
self
.
_local_transform
(
val
.
get_local_data
(
copy
=
False
),
current_info
,
axes
,
domain
,
codomain
)
# Create return object and insert results inplace
return_val
=
val
.
copy_empty
(
global_shape
=
val
.
shape
,
dtype
=
codomain
.
dtype
)
return_val
.
set_local_data
(
data
=
result
,
copy
=
False
)
return
return_val
def
_slicing_not_fftw_mpi_transform
(
self
,
val
,
domain
,
codomain
,
axes
,
**
kwargs
):
temp_val
=
val
.
copy_empty
(
distribution_strategy
=
'fftw'
)
temp_val
.
set_full_data
(
val
,
copy
=
False
)
# Recursive call to transform
result
=
self
.
transform
(
temp_val
,
domain
,
codomain
,
axes
,
**
kwargs
)
return_val
=
result
.
copy_empty
(
distribution_strategy
=
val
.
distribution_strategy
)
return_val
.
set_full_data
(
data
=
result
,
copy
=
False
)
return
return_val
def
_slicing_fftw_mpi_transform
(
self
,
val
,
domain
,
codomain
,
axes
,
**
kwargs
):
current_info
=
self
.
_get_transform_info
(
domain
,
codomain
,
**
kwargs
)
return_val
=
val
.
copy_empty
(
global_shape
=
val
.
shape
,
dtype
=
codomain
.
dtype
)
# Extract local data
local_val
=
val
.
get_local_data
(
copy
=
False
)
# Create temporary storage for slices
temp_val
=
None
# If axes tuple includes all axes, set it to None
if
axes
is
not
None
:
if
set
(
axes
)
==
set
(
range
(
len
(
val
.
shape
))):
axes
=
None
for
slice_list
in
utilities
.
get_slice_list
(
local_val
.
shape
,
axes
):
if
slice_list
==
[
slice
(
None
,
None
)]:
inp
=
local_val
else
:
if
temp_val
is
None
:
temp_val
=
np
.
empty_like
(
local_val
)
inp
=
local_val
[
slice_list
]
# This is in order to make FFTW behave properly when slicing input
# over MPI ranks when the input is 1-dimensional. The default
# behaviour is to optimize to take advantage of byte-alignment,
# which doesn't match the slicing strategy for multi-dimensional
# data.
original_shape
=
None
if
len
(
inp
.
shape
)
==
1
:
original_shape
=
inp
.
shape
inp
=
inp
.
reshape
(
inp
.
shape
[
0
],
1
)
result
=
self
.
_mpi_transform
(
inp
,
current_info
,
axes
,
domain
,
codomain
)
if
slice_list
==
[
slice
(
None
,
None
)]:
temp_val
=
result
else
:
# Reverting to the original shape i.e. before the input was
# augmented with 1 to make FFTW behave properly.
if
original_shape
is
not
None
:
result
=
result
.
reshape
(
original_shape
)
temp_val
[
slice_list
]
=
result
return_val
.
set_local_data
(
data
=
temp_val
,
copy
=
False
)
return
return_val
def
transform
(
self
,
val
,
domain
,
codomain
,
axes
=
None
,
**
kwargs
):
"""
The pyfftw transform function.
...
...
@@ -352,108 +455,22 @@ class FFTW(FFT):
if
axes
is
None
or
set
(
axes
)
==
set
(
range
(
len
(
val
.
shape
)))
\
or
0
in
axes
:
if
val
.
distribution_strategy
!=
'fftw'
:
temp_val
=
val
.
copy_empty
(
distribution_strategy
=
'fftw'
)
temp_val
.
set_full_data
(
val
,
copy
=
False
)
# Recursive call to transform
result
=
self
.
transform
(
temp_val
,
domain
,
codomain
,
axes
,
**
kwargs
)
return_val
=
result
.
copy_empty
(
distribution_strategy
=
val
.
distribution_strategy
)
return_val
.
set_full_data
(
data
=
result
,
copy
=
False
)
return_val
=
\
self
.
_slicing_not_fftw_mpi_transform
(
val
,
domain
,
codomain
,
axes
,
**
kwargs
)
else
:
current_info
=
self
.
_get_transform_info
(
domain
,
codomain
,
**
kwargs
)
return_val
=
val
.
copy_empty
(
global_shape
=
val
.
shape
,
dtype
=
codomain
.
dtype
)
# Extract local data
local_val
=
val
.
get_local_data
(
copy
=
False
)
# Create temporary storage for slices
temp_val
=
None
# If axes tuple includes all axes, set it to None
if
axes
is
not
None
:
if
set
(
axes
)
==
set
(
range
(
len
(
val
.
shape
))):
axes
=
None
for
slice_list
in
\
utilities
.
get_slice_list
(
local_val
.
shape
,
axes
):
if
slice_list
==
[
slice
(
None
,
None
)]:
inp
=
local_val
else
:
if
temp_val
is
None
:
temp_val
=
np
.
empty_like
(
local_val
)
inp
=
local_val
[
slice_list
]
# This is in order to make FFTW behave properly
# when slicing input over MPI ranks when the
# input is 1-dimensional. The default behaviour
# is to slice so that it's byte-aligned, which
# doesn't play well with multi-dimensional data
# sliced for FFTW.
original_shape
=
None
if
len
(
inp
.
shape
)
==
1
:
original_shape
=
inp
.
shape
inp
=
inp
.
reshape
(
inp
.
shape
[
0
],
1
)
result
=
self
.
_mpi_transform
(
inp
,
current_info
,
axes
,
domain
,
codomain
)
if
slice_list
==
[
slice
(
None
,
None
)]:
temp_val
=
result
else
:
# Reverting to the original shape i.e. before
# the input was augmented with 1 to make
# FFTW behave properly.
if
original_shape
is
not
None
:
result
=
result
.
reshape
(
original_shape
)
temp_val
[
slice_list
]
=
result
return_val
.
set_local_data
(
data
=
temp_val
,
copy
=
False
)
return_val
=
self
.
_slicing_fftw_mpi_transform
(
val
,
domain
,
codomain
,
axes
,
**
kwargs
)
else
:
current_info
=
self
.
_get_transform_info
(
domain
,
codomain
,
is_local
=
True
,
**
kwargs
)
# Compute transform for the local data
result
=
self
.
_local_transform
(
val
.
get_local_data
(
copy
=
False
),
current_info
,
axes
,
domain
,
codomain
return_val
=
self
.
_slicing_local_transform
(
val
,
domain
,
codomain
,
axes
,
**
kwargs
)
# Create return object and insert results inplace
return_val
=
val
.
copy_empty
(
global_shape
=
val
.
shape
,
dtype
=
codomain
.
dtype
)
return_val
.
set_local_data
(
data
=
result
,
copy
=
False
)
# If domain is purely real, the result of the FFT is hermitian
if
domain
.
paradict
[
'complexity'
]
==
0
:
return_val
.
hermitian
=
True
else
:
temp_val
=
val
.
copy_empty
(
distribution_strategy
=
'fftw'
)
about
.
warnings
.
cprint
(
'WARNING: Repacking d2o to fftw
\
distribution strategy'
)
temp_val
.
set_full_data
(
val
,
copy
=
False
)
# Recursive call to take advantage of the fact that the data
# necessary is already present on the nodes.
result
=
self
.
transform
(
temp_val
,
domain
,
codomain
,
axes
,
**
kwargs
)
return_val
=
val
.
copy_empty
(
distribution_strategy
=
val
.
distribution_strategy
return_val
=
self
.
_not_slicing_transform
(
val
,
domain
,
codomain
,
axes
,
**
kwargs
)
return_val
.
set_full_data
(
result
,
copy
=
False
)
# If domain is purely real, the result of the FFT is hermitian
if
domain
.
paradict
[
'complexity'
]
==
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment