Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
ea51586f
Commit
ea51586f
authored
May 19, 2016
by
Jait Dixit
Browse files
Add implementation for d2o when axes is not None
parent
6e4dbc02
Pipeline
#3370
skipped
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
rg/nifty_fft.py
View file @
ea51586f
...
...
@@ -235,6 +235,7 @@ class FFTW(FFT):
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# TODO +/- magic in both cases
# If the input is a numpy array we transform it locally
if
not
isinstance
(
val
,
distributed_data_object
):
# Copy data for local manipulation
...
...
@@ -251,7 +252,7 @@ class FFTW(FFT):
)
# Apply codomain centering mask
for
slice_list
in
utilities
.
get_slice_list
(
val
.
shape
,
axes
):
for
slice_list
in
utilities
.
get_slice_list
(
local_
val
.
shape
,
axes
):
if
slice_list
==
[
slice
(
None
,
None
)]:
local_val
*=
codomain_centering_mask
else
:
...
...
@@ -269,7 +270,7 @@ class FFTW(FFT):
axes
=
axes
)
# Apply domain centering mask
for
slice_list
in
utilities
.
get_slice_list
(
val
.
shape
,
axes
):
for
slice_list
in
utilities
.
get_slice_list
(
local_
val
.
shape
,
axes
):
if
slice_list
==
[
slice
(
None
,
None
)]:
return_val
*=
domain_centering_mask
else
:
...
...
@@ -277,17 +278,84 @@ class FFTW(FFT):
return
return_val
.
astype
(
codomain
.
dtype
)
else
:
# Setup the final result array
return_val
=
val
.
copy_empty
(
global_shape
=
codomain
.
get_shape
(),
dtype
=
codomain
.
type
)
if
val
.
distribution_strategy
==
'not'
:
pass
new_val
=
val
.
copy
(
distribution_strategy
=
'fftw'
)
return_val
=
self
.
transform
(
new_val
,
domain
,
codomain
,
axes
,
**
kwargs
).
copy
(
distribution_strategy
=
'not'
)
elif
val
.
distribution_strategy
in
(
'equal'
,
'fftw'
,
'freeform'
):
# slicing distributor need to examine axes before proceeding
pass
if
axes
:
# We use pyfftw in this case
# Setup up the array which will be returned
return_val
=
val
.
copy_empty
(
global_shape
=
domain
.
get_shape
(),
dtype
=
codomain
.
type
)
# Find which part of the data resides on this node
local_size
=
pyfftw
.
local_size
(
val
.
shape
)
local_start
=
local_size
[
2
]
local_end
=
local_start
+
local_size
[
1
]
# Extract the relevant data
if
val
.
distribution_strategy
==
'fftw'
:
local_val
=
val
.
get_local_data
()
else
:
local_val
=
val
.
get_data
(
slice
(
local_start
,
local_end
),
local_keys
=
True
).
get_local_data
()
# Create domain and codomain centering mask
domain_centering_mask
=
self
.
get_centering_mask
(
domain
.
paradict
[
'zerocenter'
],
domain
.
get_shape
()
)
codomain_centering_mask
=
self
.
get_centering_mask
(
codomain
.
paradict
[
'zerocenter'
],
codomain
.
get_shape
()
)
# Apply codomain centering mask
for
slice_list
in
utilities
.
get_slice_list
(
local_val
.
shape
,
axes
):
local_val
[
slice_list
]
*=
codomain_centering_mask
if
codomain
.
harmonic
:
result
=
pyfftw
.
interfaces
.
numpy_fft
.
fftn
(
local_val
,
axes
=
axes
)
else
:
result
=
pyfftw
.
interfaces
.
numpy_fft
.
ifftn
(
local_val
,
axes
=
axes
)
# Apply domain centering mask
for
slice_list
in
utilities
.
get_slice_list
(
local_val
.
shape
,
axes
):
result
[
slice_list
]
*=
domain_centering_mask
# Push data in-place in the array to be returned
if
return_val
.
distribution_strategy
==
'fftw'
:
return_val
.
set_local_data
(
result
,
copy
=
False
)
else
:
return_val
.
set_data
(
data
=
result
,
to_key
=
slice
(
local_start
,
local_end
),
local_keys
=
True
)
if
domain
.
paradict
[
'complexity'
]
==
0
:
return_val
.
hermitian
=
True
elif
not
axes
or
axes
==
(
0
,):
# We use pyfftw-mpi in this case
pass
else
:
raise
ValueError
(
'ERROR: Unknown distribution strategy'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment