Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
b6257738
Commit
b6257738
authored
Jun 02, 2017
by
Martin Reinecke
Browse files
allow explicit selection of scalar FFT module
parent
dcebff53
Pipeline
#13265
passed with stage
in 5 minutes and 19 seconds
Changes
4
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty/config/nifty_config.py
View file @
b6257738
...
...
@@ -35,12 +35,19 @@ dependency_injector.register(('pyfftw', 'fftw_mpi'),
lambda
z
:
hasattr
(
z
,
'FFTW_MPI'
))
dependency_injector
.
register
((
'pyfftw'
,
'fftw_scalar'
))
def
_fft_module_checker
(
z
):
if
z
==
'mpi_fftw'
:
return
'fftw_mpi'
in
dependency_injector
if
z
==
'scalar_fftw'
:
return
'fftw_scalar'
in
dependency_injector
return
True
# Initialize the variables
variable_fft_module
=
keepers
.
Variable
(
'fft_module'
,
[
'mpi'
,
'scalar'
],
lambda
z
:
((
'fftw_mpi'
in
dependency_injector
)
if
z
==
'mpi'
else
True
))
[
'mpi_fftw'
,
'scalar_fftw'
,
'scalar_numpy'
],
lambda
z
:
_fft_module_checker
(
z
))
def
dtype_validator
(
dtype
):
...
...
nifty/operators/fft_operator/transformations/rg_transforms.py
View file @
b6257738
...
...
@@ -553,6 +553,17 @@ class ScalarFFT(Transform):
The numpy fft pendant of a fft object.
"""
def
__init__
(
self
,
domain
,
codomain
,
fftw
):
super
(
ScalarFFT
,
self
).
__init__
(
domain
,
codomain
)
if
fftw
and
(
fftw_scalar
is
None
):
raise
ImportError
(
"The scalar FFTW module is needed but not available."
)
self
.
_fftw
=
fftw
# Enable caching
if
self
.
_fftw
:
fftw_scalar
.
interfaces
.
cache
.
enable
()
def
transform
(
self
,
val
,
axes
,
**
kwargs
):
"""
...
...
@@ -575,9 +586,6 @@ class ScalarFFT(Transform):
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# Enable caching
if
fftw_scalar
is
not
None
:
fftw_scalar
.
interfaces
.
cache
.
enable
()
# Check if the axes provided are valid given the shape
if
axes
is
not
None
and
\
...
...
@@ -632,7 +640,7 @@ class ScalarFFT(Transform):
local_val
=
self
.
_apply_mask
(
temp_val
,
mask
,
axes
)
# perform the transformation
if
fftw_scalar
is
not
None
:
if
self
.
_fftw
:
if
self
.
codomain
.
harmonic
:
result_val
=
fftw_scalar
.
interfaces
.
numpy_fft
.
fftn
(
local_val
,
axes
=
axes
)
...
...
nifty/operators/fft_operator/transformations/rgrgtransformation.py
View file @
b6257738
...
...
@@ -30,20 +30,16 @@ class RGRGTransformation(Transformation):
super
(
RGRGTransformation
,
self
).
__init__
(
domain
,
codomain
,
module
)
if
module
is
None
:
if
nifty_configuration
[
'fft_module'
]
==
'mpi'
:
self
.
_transform
=
MPIFFT
(
self
.
domain
,
self
.
codomain
)
elif
nifty_configuration
[
'fft_module'
]
==
'scalar'
:
self
.
_transform
=
ScalarFFT
(
self
.
domain
,
self
.
codomain
)
else
:
raise
ValueError
(
'Unsupported default FFT module:'
+
nifty_configuration
[
'fft_module'
])
module
=
nifty_configuration
[
'fft_module'
]
if
module
==
'mpi_fftw'
:
self
.
_transform
=
MPIFFT
(
self
.
domain
,
self
.
codomain
)
elif
module
==
'scalar_fftw'
:
self
.
_transform
=
ScalarFFT
(
self
.
domain
,
self
.
codomain
,
True
)
elif
module
==
'scalar_numpy'
:
self
.
_transform
=
ScalarFFT
(
self
.
domain
,
self
.
codomain
,
False
)
else
:
if
module
==
'mpi'
:
self
.
_transform
=
MPIFFT
(
self
.
domain
,
self
.
codomain
)
elif
module
==
'scalar'
:
self
.
_transform
=
ScalarFFT
(
self
.
domain
,
self
.
codomain
)
else
:
raise
ValueError
(
'Unsupported FFT module:'
+
module
)
raise
ValueError
(
'Unsupported FFT module:'
+
module
)
# ---Mandatory properties and methods---
...
...
test/test_operators/test_fft_operator.py
View file @
b6257738
...
...
@@ -20,7 +20,7 @@ import unittest
import
numpy
as
np
from
numpy.testing
import
assert_equal
,
\
assert_allclose
from
nifty.config
import
dependency_injector
as
di
from
nifty.config
import
dependency_injector
as
g
di
from
nifty
import
Field
,
\
RGSpace
,
\
LMSpace
,
\
...
...
@@ -62,11 +62,14 @@ class FFTOperatorTests(unittest.TestCase):
res
=
foo
.
get_distance_array
(
'not'
)
assert_equal
(
res
[
zc1
*
(
dim1
//
2
),
zc2
*
(
dim2
//
2
)],
0.
)
@
expand
(
product
([
"scalar"
,
"mpi"
],
[
10
,
11
],
[
False
,
True
],
[
False
,
True
],
@
expand
(
product
([
"scalar_numpy"
,
"scalar_fftw"
,
"mpi_fftw"
],
[
10
,
11
],
[
False
,
True
],
[
False
,
True
],
[
0.1
,
1
,
3.7
],
[
np
.
float64
,
np
.
complex128
,
np
.
float32
,
np
.
complex64
]))
def
test_fft1D
(
self
,
module
,
dim1
,
zc1
,
zc2
,
d
,
itp
):
if
module
==
"mpi"
and
"fftw_mpi"
not
in
di
:
if
module
==
"mpi_fftw"
and
"fftw_mpi"
not
in
gdi
:
raise
SkipTest
if
module
==
"scalar_fftw"
and
"fftw_scalar"
not
in
gdi
:
raise
SkipTest
tol
=
_get_rtol
(
itp
)
a
=
RGSpace
(
dim1
,
zerocenter
=
zc1
,
distances
=
d
)
...
...
@@ -78,12 +81,15 @@ class FFTOperatorTests(unittest.TestCase):
out
=
fft
.
adjoint_times
(
fft
.
times
(
inp
))
assert_allclose
(
inp
.
val
,
out
.
val
,
rtol
=
tol
,
atol
=
tol
)
@
expand
(
product
([
"scalar"
,
"mpi"
],
[
10
,
11
],
[
9
,
12
],
[
False
,
True
],
@
expand
(
product
([
"scalar_numpy"
,
"scalar_fftw"
,
"mpi_fftw"
],
[
10
,
11
],
[
9
,
12
],
[
False
,
True
],
[
False
,
True
],
[
False
,
True
],
[
False
,
True
],
[
0.1
,
1
,
3.7
],
[
0.4
,
1
,
2.7
],
[
np
.
float64
,
np
.
complex128
,
np
.
float32
,
np
.
complex64
]))
def
test_fft2D
(
self
,
module
,
dim1
,
dim2
,
zc1
,
zc2
,
zc3
,
zc4
,
d1
,
d2
,
itp
):
if
module
==
"mpi"
and
"fftw_mpi"
not
in
di
:
if
module
==
"mpi_fftw"
and
"fftw_mpi"
not
in
gdi
:
raise
SkipTest
if
module
==
"scalar_fftw"
and
"fftw_scalar"
not
in
gdi
:
raise
SkipTest
tol
=
_get_rtol
(
itp
)
a
=
RGSpace
([
dim1
,
dim2
],
zerocenter
=
[
zc1
,
zc2
],
distances
=
[
d1
,
d2
])
...
...
@@ -99,7 +105,7 @@ class FFTOperatorTests(unittest.TestCase):
@
expand
(
product
([
0
,
3
,
6
,
11
,
30
],
[
np
.
float64
,
np
.
complex128
,
np
.
float32
,
np
.
complex64
]))
def
test_sht
(
self
,
lm
,
tp
):
if
'pyHealpix'
not
in
di
:
if
'pyHealpix'
not
in
g
di
:
raise
SkipTest
tol
=
_get_rtol
(
tp
)
a
=
LMSpace
(
lmax
=
lm
)
...
...
@@ -113,7 +119,7 @@ class FFTOperatorTests(unittest.TestCase):
@
expand
(
product
([
128
,
256
],
[
np
.
float64
,
np
.
complex128
,
np
.
float32
,
np
.
complex64
]))
def
test_sht2
(
self
,
lm
,
tp
):
if
'pyHealpix'
not
in
di
:
if
'pyHealpix'
not
in
g
di
:
raise
SkipTest
a
=
LMSpace
(
lmax
=
lm
)
b
=
HPSpace
(
nside
=
lm
//
2
)
...
...
@@ -126,7 +132,7 @@ class FFTOperatorTests(unittest.TestCase):
@
expand
(
product
([
128
,
256
],
[
np
.
float64
,
np
.
complex128
,
np
.
float32
,
np
.
complex64
]))
def
test_dotsht
(
self
,
lm
,
tp
):
if
'pyHealpix'
not
in
di
:
if
'pyHealpix'
not
in
g
di
:
raise
SkipTest
tol
=
_get_rtol
(
tp
)
a
=
LMSpace
(
lmax
=
lm
)
...
...
@@ -142,7 +148,7 @@ class FFTOperatorTests(unittest.TestCase):
@
expand
(
product
([
128
,
256
],
[
np
.
float64
,
np
.
complex128
,
np
.
float32
,
np
.
complex64
]))
def
test_dotsht2
(
self
,
lm
,
tp
):
if
'pyHealpix'
not
in
di
:
if
'pyHealpix'
not
in
g
di
:
raise
SkipTest
tol
=
_get_rtol
(
tp
)
a
=
LMSpace
(
lmax
=
lm
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment