Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
1182e522
Commit
1182e522
authored
Jun 14, 2016
by
theos
Browse files
Fixed handling of centering masks in nifty_fft.py.
Condensed functional structure in nifty_fft.py.
parent
157f4fea
Pipeline
#5114
skipped
Changes
2
Pipelines
1
Show whitespace changes
Inline
Side-by-side
rg/nifty_fft.py
View file @
1182e522
# -*- coding: utf-8 -*-
import
warnings
import
numpy
as
np
from
mpi4py
import
MPI
from
d2o
import
distributed_data_object
,
distributor_factory
,
STRATEGIES
from
nifty.config
import
about
,
dependency_injector
as
gdi
from
d2o
import
distributed_data_object
,
\
STRATEGIES
from
nifty.config
import
about
,
\
dependency_injector
as
gdi
import
nifty.nifty_utilities
as
utilities
pyfftw
=
gdi
.
get
(
'pyfftw'
)
...
...
@@ -11,23 +14,6 @@ gfft = gdi.get('gfft')
gfft_dummy
=
gdi
.
get
(
'gfft_dummy'
)
# Try to import pyfftw. If this fails fall back to gfft.
# If this fails fall back to local gfft_rg
# try:
# import pyfftw
# fft_machine='pyfftw'
# except(ImportError):
# try:
# import gfft
# fft_machine='gfft'
# about.infos.cprint('INFO: Using gfft')
# except(ImportError):
# import gfft_rg as gfft
# fft_machine='gfft_fallback'
# about.infos.cprint('INFO: Using builtin "plain" gfft version 0.1.0')
def
fft_factory
(
fft_module_name
):
"""
A factory for fast-fourier-transformation objects.
...
...
@@ -103,7 +89,7 @@ class FFTW(FFT):
pyfftw
.
interfaces
.
cache
.
enable
()
def
get_centering_mask
(
self
,
to_center_input
,
dimensions_input
,
offset_input
=
0
):
offset_input
=
False
):
"""
Computes the mask, used to (de-)zerocenter domain and target
fields.
...
...
@@ -205,8 +191,8 @@ class FFTW(FFT):
self
.
centering_mask_dict
[
temp_id
]
=
centering_mask
return
self
.
centering_mask_dict
[
temp_id
]
def
_get_transform_info
(
self
,
domain
,
codomain
,
local_shape
_info
=
None
,
is_local
=
False
,
**
kwargs
):
def
_get_transform_info
(
self
,
domain
,
codomain
,
local_shape
,
local_offset_Q
,
is_local
,
**
kwargs
):
# generate a id-tuple which identifies the domain-codomain setting
temp_id
=
domain
.
__hash__
()
^
(
101
*
codomain
.
__hash__
())
...
...
@@ -214,12 +200,12 @@ class FFTW(FFT):
if
temp_id
not
in
self
.
info_dict
:
if
is_local
:
self
.
info_dict
[
temp_id
]
=
FFTWLocalTransformInfo
(
domain
,
codomain
,
local_shape
_info
,
domain
,
codomain
,
local_shape
,
local_offset_Q
,
self
,
**
kwargs
)
else
:
self
.
info_dict
[
temp_id
]
=
FFTWMPITransfromInfo
(
domain
,
codomain
,
local_shape
_info
,
domain
,
codomain
,
local_shape
,
local_offset_Q
,
self
,
**
kwargs
)
...
...
@@ -254,25 +240,7 @@ class FFTW(FFT):
return
val
*
mask
def
_local_transform
(
self
,
val
,
info
,
axes
,
domain
,
codomain
):
# Apply codomain centering mask
if
reduce
(
lambda
x
,
y
:
x
+
y
,
codomain
.
paradict
[
'zerocenter'
]):
temp_val
=
np
.
copy
(
val
)
val
=
self
.
_apply_mask
(
temp_val
,
info
.
cmask_codomain
,
axes
)
result
=
info
.
fftw_interface
(
val
,
axes
=
axes
,
planner_effort
=
'FFTW_ESTIMATE'
)
# Apply domain centering mask
if
reduce
(
lambda
x
,
y
:
x
+
y
,
domain
.
paradict
[
'zerocenter'
]):
result
=
self
.
_apply_mask
(
result
,
info
.
cmask_domain
,
axes
)
# Correct the sign if needed
result
*=
info
.
sign
return
result
def
_mpi_transform
(
self
,
val
,
info
,
axes
,
domain
,
codomain
):
def
_atomic_mpi_transform
(
self
,
val
,
info
,
axes
,
domain
,
codomain
):
# Apply codomain centering mask
if
reduce
(
lambda
x
,
y
:
x
+
y
,
codomain
.
paradict
[
'zerocenter'
]):
temp_val
=
np
.
copy
(
val
)
...
...
@@ -288,7 +256,7 @@ class FFTW(FFT):
if
p
.
has_output
:
result
=
p
.
output_array
else
:
r
aise
RuntimeError
(
'ERROR: PyFFTW-MPI transform failed.'
)
r
eturn
None
# Apply domain centering mask
if
reduce
(
lambda
x
,
y
:
x
+
y
,
domain
.
paradict
[
'zerocenter'
]):
...
...
@@ -299,74 +267,88 @@ class FFTW(FFT):
return
result
def
_not_slicing_transform
(
self
,
val
,
domain
,
codomain
,
axes
,
**
kwargs
):
temp_val
=
val
.
copy_empty
(
distribution_strategy
=
'fftw'
)
about
.
warnings
.
cprint
(
'WARNING: Repacking d2o to fftw
\
distribution strategy'
)
temp_val
.
set_full_data
(
val
,
copy
=
False
)
# Recursive call to take advantage of the fact that the data
# necessary is already present on the nodes.
result
=
self
.
transform
(
temp_val
,
domain
,
codomain
,
axes
,
def
_local_transform
(
self
,
val
,
domain
,
codomain
,
axes
,
**
kwargs
):
####
# val must be numpy array or d2o with slicing distributor
###
local_offset_Q
=
False
try
:
local_val
=
val
.
get_local_data
(
copy
=
False
),
if
axes
is
None
or
0
in
axes
:
local_offset_Q
=
val
.
distributor
.
local_shape
[
0
]
%
2
except
(
AttributeError
):
local_val
=
val
current_info
=
self
.
_get_transform_info
(
domain
,
codomain
,
local_shape
=
local_val
.
shape
,
local_offset_Q
=
local_offset_Q
,
is_local
=
True
,
**
kwargs
)
return_val
=
val
.
copy_empty
(
distribution_strategy
=
val
.
distribution_strategy
)
return_val
.
set_full_data
(
result
,
copy
=
False
)
# Apply codomain centering mask
if
reduce
(
lambda
x
,
y
:
x
+
y
,
codomain
.
paradict
[
'zerocenter'
]):
temp_val
=
np
.
copy
(
local_val
)
local_val
=
self
.
_apply_mask
(
temp_val
,
current_info
.
cmask_codomain
,
axes
)
return
return_val
local_result
=
current_info
.
fftw_interface
(
local_val
,
axes
=
axes
,
planner_effort
=
'FFTW_ESTIMATE'
)
def
_slicing_local_transform
(
self
,
val
,
domain
,
codomain
,
axes
,
**
kwargs
):
current_info
=
self
.
_get_transform_info
(
domain
,
codomain
,
is_local
=
True
,
**
kwargs
)
# Apply domain centering mask
if
reduce
(
lambda
x
,
y
:
x
+
y
,
domain
.
paradict
[
'zerocenter'
]):
local_result
=
self
.
_apply_mask
(
local_result
,
current_info
.
cmask_domain
,
axes
)
# Compute transform for the local data
result
=
self
.
_local_transform
(
val
.
get_local_data
(
copy
=
False
),
current_info
,
axes
,
domain
,
codomain
)
# Correct the sign if needed
if
current_info
.
sign
!=
1
:
local_result
*=
current_info
.
sign
try
:
# Create return object and insert results inplace
return_val
=
val
.
copy_empty
(
global_shape
=
val
.
shape
,
dtype
=
codomain
.
dtype
)
return_val
.
set_local_data
(
data
=
result
,
copy
=
False
)
return_val
.
set_local_data
(
data
=
local_result
,
copy
=
False
)
except
(
AttributeError
):
return_val
=
local_result
return
return_val
def
_
slicing_not
_fftw_
mpi
_transform
(
self
,
val
,
domain
,
codomain
,
def
_
repack_to
_fftw_
and
_transform
(
self
,
val
,
domain
,
codomain
,
axes
,
**
kwargs
):
temp_val
=
val
.
copy_empty
(
distribution_strategy
=
'fftw'
)
about
.
warnings
.
cprint
(
'WARNING: Repacking d2o to fftw
\
distribution strategy'
)
temp_val
.
set_full_data
(
val
,
copy
=
False
)
# Recursive call to transform
result
=
self
.
transform
(
temp_val
,
domain
,
codomain
,
axes
,
**
kwargs
)
return_val
=
result
.
copy_empty
(
distribution_strategy
=
val
.
distribution_strategy
)
distribution_strategy
=
val
.
distribution_strategy
)
return_val
.
set_full_data
(
data
=
result
,
copy
=
False
)
return
return_val
def
_get_local_shape_info
(
self
,
comm
,
global_shape
,
distribution_strategy
):
if
distribution_strategy
==
'equal'
:
local_slice
=
distributor_factory
.
_equal_slicer
(
comm
,
global_shape
)
local_shape
=
np
.
append
((
local_slice
[
1
]
-
local_slice
[
0
],),
global_shape
[
1
:])
return
(
local_shape
,
local_slice
[
0
])
def
_slicing_fftw_mpi_transform
(
self
,
val
,
domain
,
codomain
,
axes
,
**
kwargs
):
current_info
=
self
.
_get_transform_info
(
domain
,
codomain
,
local_shape_info
=
self
.
_get_local_shape_info
(
val
.
comm
,
val
.
shape
,
val
.
distribution_strategy
),
**
kwargs
)
def
_mpi_transform
(
self
,
val
,
domain
,
codomain
,
axes
,
**
kwargs
):
if
axes
is
None
or
0
in
axes
:
local_offset_list
=
np
.
cumsum
(
np
.
concatenate
(
[[
0
,
],
val
.
distributor
.
all_local_slices
[:,
2
]]))
local_offset_Q
=
bool
(
local_offset_list
[
val
.
distributor
.
comm
.
rank
]
%
2
)
else
:
local_offset_Q
=
False
current_info
=
self
.
_get_transform_info
(
domain
,
codomain
,
local_shape
=
val
.
local_shape
,
local_offset_Q
=
local_offset_Q
,
is_local
=
False
,
**
kwargs
)
return_val
=
val
.
copy_empty
(
global_shape
=
val
.
shape
,
dtype
=
codomain
.
dtype
)
...
...
@@ -399,10 +381,14 @@ class FFTW(FFT):
original_shape
=
inp
.
shape
inp
=
inp
.
reshape
(
inp
.
shape
[
0
],
1
)
result
=
self
.
_mpi_transform
(
inp
,
current_info
,
axes
,
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
"ignore"
)
result
=
self
.
_atomic_mpi_transform
(
inp
,
current_info
,
axes
,
domain
,
codomain
)
if
slice_list
==
[
slice
(
None
,
None
)]:
if
result
is
None
:
temp_val
=
np
.
empty_like
(
local_val
)
elif
slice_list
==
[
slice
(
None
,
None
)]:
temp_val
=
result
else
:
# Reverting to the original shape i.e. before the input was
...
...
@@ -460,28 +446,23 @@ class FFTW(FFT):
return_val
=
self
.
_local_transform
(
temp_val
,
current_info
,
axes
,
domain
,
codomain
)
else
:
if
val
.
comm
is
not
MPI
.
COMM_WORLD
:
raise
RuntimeError
(
'ERROR: Input array uses an unsupported
\
comm object'
)
if
val
.
distribution_strategy
in
STRATEGIES
[
'slicing'
]:
if
axes
is
None
or
set
(
axes
)
==
set
(
range
(
len
(
val
.
shape
)))
\
or
0
in
axes
:
if
axes
is
None
or
0
in
axes
:
if
val
.
distribution_strategy
!=
'fftw'
:
return_val
=
\
self
.
_
slicing_not
_fftw_
mpi
_transform
(
self
.
_
repack_to
_fftw_
and
_transform
(
val
,
domain
,
codomain
,
axes
,
**
kwargs
)
else
:
return_val
=
self
.
_
slicing_fftw_
mpi_transform
(
return_val
=
self
.
_mpi_transform
(
val
,
domain
,
codomain
,
axes
,
**
kwargs
)
else
:
return_val
=
self
.
_
slicing_
local_transform
(
return_val
=
self
.
_local_transform
(
val
,
domain
,
codomain
,
axes
,
**
kwargs
)
else
:
return_val
=
self
.
_
not_slicing
_transform
(
return_val
=
self
.
_
repack_to_fftw_and
_transform
(
val
,
domain
,
codomain
,
axes
,
**
kwargs
)
...
...
@@ -494,49 +475,26 @@ class FFTW(FFT):
class
FFTWTransformInfo
(
object
):
def
__init__
(
self
,
domain
,
codomain
,
local_shape
_info
,
def
__init__
(
self
,
domain
,
codomain
,
local_shape
,
local_offset_Q
,
fftw_context
,
**
kwargs
):
if
pyfftw
is
None
:
raise
ImportError
(
"The module pyfftw is needed but not available."
)
# When the domain being transformed is not split across ranks, the
# mask will then have the same shape as the domain. The offset
# is set to False since every node will have index starting from 0. In
# the other case, we use the supplied local_shape_info to get the
# local_shape and offset
if
local_shape_info
is
None
:
self
.
cmask_domain
=
fftw_context
.
get_centering_mask
(
domain
.
paradict
[
'zerocenter'
],
domain
.
get_shape
(),
False
)
else
:
self
.
cmask_domain
=
fftw_context
.
get_centering_mask
(
domain
.
paradict
[
'zerocenter'
],
local_shape_info
[
0
],
local_shape_info
[
1
]
%
2
)
local_shape
,
local_offset_Q
)
if
local_shape_info
is
None
:
self
.
cmask_codomain
=
fftw_context
.
get_centering_mask
(
codomain
.
paradict
[
'zerocenter'
],
codomain
.
get_shape
(),
False
)
else
:
self
.
cmask_domain
=
fftw_context
.
get_centering_mask
(
codomain
.
paradict
[
'zerocenter'
],
local_shape_info
[
0
],
local_shape_info
[
1
]
%
2
)
local_shape
,
local_offset_Q
)
# If both domain and codomain are zero-centered the result,
# will get a global minus. Store the sign to correct it.
self
.
sign
=
(
-
1
)
**
np
.
sum
(
np
.
array
(
domain
.
paradict
[
'zerocenter'
])
*
self
.
sign
=
(
-
1
)
**
np
.
sum
(
np
.
array
(
domain
.
paradict
[
'zerocenter'
])
*
np
.
array
(
codomain
.
paradict
[
'zerocenter'
])
*
(
np
.
array
(
domain
.
get_shape
())
//
2
%
2
)
)
(
np
.
array
(
domain
.
get_shape
())
//
2
%
2
))
@
property
def
cmask_domain
(
self
):
...
...
@@ -564,11 +522,14 @@ class FFTWTransformInfo(object):
class
FFTWLocalTransformInfo
(
FFTWTransformInfo
):
def
__init__
(
self
,
domain
,
codomain
,
local_shape
_info
,
def
__init__
(
self
,
domain
,
codomain
,
local_shape
,
local_offset_Q
,
fftw_context
,
**
kwargs
):
super
(
FFTWLocalTransformInfo
,
self
).
__init__
(
domain
,
codomain
,
local_shape_info
,
fftw_context
,
**
kwargs
)
super
(
FFTWLocalTransformInfo
,
self
).
__init__
(
domain
,
codomain
,
local_shape
,
local_offset_Q
,
fftw_context
,
**
kwargs
)
if
codomain
.
harmonic
:
self
.
_fftw_interface
=
pyfftw
.
interfaces
.
numpy_fft
.
fftn
else
:
...
...
@@ -585,11 +546,14 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
class
FFTWMPITransfromInfo
(
FFTWTransformInfo
):
def
__init__
(
self
,
domain
,
codomain
,
local_shape
_info
,
def
__init__
(
self
,
domain
,
codomain
,
local_shape
,
local_offset_Q
,
fftw_context
,
**
kwargs
):
super
(
FFTWMPITransfromInfo
,
self
).
__init__
(
domain
,
codomain
,
local_shape_info
,
fftw_context
,
**
kwargs
)
super
(
FFTWMPITransfromInfo
,
self
).
__init__
(
domain
,
codomain
,
local_shape
,
local_offset_Q
,
fftw_context
,
**
kwargs
)
# When the domain is 1-dimensional, reshape it so that it can
# accept input which is also augmented by 1.
if
len
(
domain
.
get_shape
())
==
1
:
...
...
test/test_nifty_spaces.py
View file @
1182e522
...
...
@@ -793,35 +793,35 @@ class Test_RG_Space(unittest.TestCase):
###############################################################################
@
parameterized
.
expand
(
itertools
.
product
([
True
],
#[True, False],
[
'
py
fftw'
]),
#DATAMODELS['rg_space']),
testcase_func_name
=
custom_name_func
)
def
test_get_random_values
(
self
,
harmonic
,
datamodel
):
x
=
rg_space
((
4
,
4
),
complexity
=
1
,
harmonic
=
harmonic
,
datamodel
=
datamodel
)
# pm1
data
=
x
.
get_random_values
(
random
=
'pm1'
)
flipped_data
=
flip
(
x
,
data
)
assert
(
check_almost_equality
(
x
,
data
,
flipped_data
))
# gau
data
=
x
.
get_random_values
(
random
=
'gau'
,
mean
=
4
+
3j
,
std
=
2
)
flipped_data
=
flip
(
x
,
data
)
assert
(
check_almost_equality
(
x
,
data
,
flipped_data
))
# uni
data
=
x
.
get_random_values
(
random
=
'uni'
,
vmin
=-
2
,
vmax
=
4
)
flipped_data
=
flip
(
x
,
data
)
assert
(
check_almost_equality
(
x
,
data
,
flipped_data
))
# syn
data
=
x
.
get_random_values
(
random
=
'syn'
,
spec
=
lambda
x
:
42
/
(
1
+
x
)
**
3
)
flipped_data
=
flip
(
x
,
data
)
assert
(
check_almost_equality
(
x
,
data
,
flipped_data
))
#
@parameterized.expand(
#
itertools.product([True], #[True, False],
#
['fftw']),
#
#DATAMODELS['rg_space']),
#
testcase_func_name=custom_name_func)
#
def test_get_random_values(self, harmonic, datamodel):
#
x = rg_space((4, 4), complexity=1, harmonic=harmonic,
#
datamodel=datamodel)
#
#
# pm1
#
data = x.get_random_values(random='pm1')
#
flipped_data = flip(x, data)
#
assert(check_almost_equality(x, data, flipped_data))
#
#
# gau
#
data = x.get_random_values(random='gau', mean=4 + 3j, std=2)
#
flipped_data = flip(x, data)
#
assert(check_almost_equality(x, data, flipped_data))
#
#
# uni
#
data = x.get_random_values(random='uni', vmin=-2, vmax=4)
#
flipped_data = flip(x, data)
#
assert(check_almost_equality(x, data, flipped_data))
#
#
# syn
#
data = x.get_random_values(random='syn',
#
spec=lambda x: 42 / (1 + x)**3)
#
flipped_data = flip(x, data)
#
assert(check_almost_equality(x, data, flipped_data))
###############################################################################
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment