Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
b0d67512
Commit
b0d67512
authored
May 30, 2017
by
Martin Reinecke
Browse files
renaming, WIP
parent
88d6e60d
Pipeline
#13115
passed with stage
in 5 minutes and 12 seconds
Changes
6
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty/config/nifty_config.py
View file @
b0d67512
...
...
@@ -31,15 +31,15 @@ dependency_injector = keepers.DependencyInjector(
'pyHealpix'
,
'plotly'
])
dependency_injector
.
register
(
'pyfftw'
,
lambda
z
:
hasattr
(
z
,
'FFTW_MPI'
))
dependency_injector
.
register
((
'pyfftw'
,
'
py
fftw_scalar'
))
dependency_injector
.
register
(
(
'pyfftw'
,
'fftw_mpi'
),
lambda
z
:
hasattr
(
z
,
'FFTW_MPI'
))
dependency_injector
.
register
((
'pyfftw'
,
'fftw_scalar'
))
# Initialize the variables
variable_fft_module
=
keepers
.
Variable
(
'fft_module'
,
[
'
fftw'
,
'numpy
'
],
lambda
z
:
((
'
py
fftw'
in
dependency_injector
)
if
z
==
'
fftw
'
else
True
))
[
'
mpi'
,
'scalar
'
],
lambda
z
:
((
'fftw
_mpi
'
in
dependency_injector
)
if
z
==
'
mpi
'
else
True
))
def
dtype_validator
(
dtype
):
...
...
nifty/operators/fft_operator/fft_operator.py
View file @
b0d67512
...
...
@@ -63,9 +63,10 @@ class FFTOperator(LinearOperator):
but for full control, the user should explicitly specify a codomain.
module: String (optional)
Software module employed for carrying out the transform operations.
For RGSpace pairs this can be "numpy" or "fftw", where "numpy" is
always available, but "fftw" offers higher performance and
parallelization. For sphere-related domains, only "pyHealpix" is
For RGSpace pairs this can be "scalar" or "mpi", where "scalar" is
always available (using pyfftw if available, else numpy.fft), and "mpi"
requires pyfftw and offers MPI parallelization.
For sphere-related domains, only "pyHealpix" is
available. If omitted, "fftw" is selected for RGSpaces if available,
else "numpy"; on the sphere the default is "pyHealpix".
domain_dtype: data type (optional)
...
...
nifty/operators/fft_operator/transformations/rg_transforms.py
View file @
b0d67512
...
...
@@ -25,8 +25,8 @@ import nifty.nifty_utilities as utilities
from
keepers
import
Loggable
py
fftw
=
gdi
.
get
(
'
py
fftw'
)
py
fftw_scalar
=
gdi
.
get
(
'
py
fftw_scalar'
)
fftw
_mpi
=
gdi
.
get
(
'fftw
_mpi
'
)
fftw_scalar
=
gdi
.
get
(
'fftw_scalar'
)
class
Transform
(
Loggable
,
object
):
...
...
@@ -201,20 +201,21 @@ class Transform(Loggable, object):
raise
NotImplementedError
class
FFT
W
(
Transform
):
class
MPI
FFT
(
Transform
):
"""
The
pyfftw
pendant of a fft object.
The
MPI-parallel FFTW
pendant of a fft object.
"""
def
__init__
(
self
,
domain
,
codomain
):
if
'pyfftw'
not
in
gdi
:
raise
ImportError
(
"The module pyfftw is needed but not available."
)
if
'fftw_mpi'
not
in
gdi
:
raise
ImportError
(
"The MPI FFTW module is needed but not available."
)
super
(
FFT
W
,
self
).
__init__
(
domain
,
codomain
)
super
(
MPI
FFT
,
self
).
__init__
(
domain
,
codomain
)
# Enable caching
for pyfftw.interfaces
py
fftw
.
interfaces
.
cache
.
enable
()
# Enable caching
fftw
_mpi
.
interfaces
.
cache
.
enable
()
# The plan_dict stores the FFTWTransformInfo objects which correspond
# to a certain set of (field_val, domain, codomain) sets.
...
...
@@ -410,7 +411,7 @@ class FFTW(Transform):
def
transform
(
self
,
val
,
axes
,
**
kwargs
):
"""
The
pyfftw
transform function.
The
MPI-parallel FFTW
transform function.
Parameters
----------
...
...
@@ -468,8 +469,9 @@ class FFTW(Transform):
class
FFTWTransformInfo
(
object
):
def
__init__
(
self
,
domain
,
codomain
,
axes
,
local_shape
,
local_offset_Q
,
fftw_context
,
**
kwargs
):
if
pyfftw
is
None
:
raise
ImportError
(
"The module pyfftw is needed but not available."
)
if
fftw_mpi
is
None
:
raise
ImportError
(
"The MPI FFTW module is needed but not available."
)
shape
=
(
local_shape
if
axes
is
None
else
[
y
for
x
,
y
in
enumerate
(
local_shape
)
if
x
in
axes
])
...
...
@@ -513,9 +515,9 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
fftw_context
,
**
kwargs
)
if
codomain
.
harmonic
:
self
.
_fftw_interface
=
py
fftw
.
interfaces
.
numpy_fft
.
fftn
self
.
_fftw_interface
=
fftw
_mpi
.
interfaces
.
numpy_fft
.
fftn
else
:
self
.
_fftw_interface
=
py
fftw
.
interfaces
.
numpy_fft
.
ifftn
self
.
_fftw_interface
=
fftw
_mpi
.
interfaces
.
numpy_fft
.
ifftn
@
property
def
fftw_interface
(
self
):
...
...
@@ -532,7 +534,7 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
local_offset_Q
,
fftw_context
,
**
kwargs
)
self
.
_plan
=
py
fftw
.
create_mpi_plan
(
self
.
_plan
=
fftw
_mpi
.
create_mpi_plan
(
input_shape
=
transform_shape
,
input_dtype
=
'complex128'
,
output_dtype
=
'complex128'
,
...
...
@@ -546,7 +548,7 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
return
self
.
_plan
class
NUMPY
FFT
(
Transform
):
class
Scalar
FFT
(
Transform
):
"""
The numpy fft pendant of a fft object.
...
...
@@ -554,7 +556,7 @@ class NUMPYFFT(Transform):
def
transform
(
self
,
val
,
axes
,
**
kwargs
):
"""
The
pyfftw
transform function.
The
scalar FFT
transform function.
Parameters
----------
...
...
@@ -573,9 +575,9 @@ class NUMPYFFT(Transform):
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# Enable caching
for pyfftw_scalar.interfaces
if
'
py
fftw_scalar'
in
gdi
:
py
fftw_scalar
.
interfaces
.
cache
.
enable
()
# Enable caching
if
'fftw_scalar'
in
gdi
:
fftw_scalar
.
interfaces
.
cache
.
enable
()
# Check if the axes provided are valid given the shape
if
axes
is
not
None
and
\
...
...
@@ -630,11 +632,11 @@ class NUMPYFFT(Transform):
local_val
=
self
.
_apply_mask
(
temp_val
,
mask
,
axes
)
# perform the transformation
if
'
py
fftw_scalar'
in
gdi
:
if
'fftw_scalar'
in
gdi
:
if
self
.
codomain
.
harmonic
:
result_val
=
py
fftw_scalar
.
interfaces
.
numpy_fft
.
fftn
(
local_val
,
axes
=
axes
)
result_val
=
fftw_scalar
.
interfaces
.
numpy_fft
.
fftn
(
local_val
,
axes
=
axes
)
else
:
result_val
=
py
fftw_scalar
.
interfaces
.
numpy_fft
.
ifftn
(
local_val
,
axes
=
axes
)
result_val
=
fftw_scalar
.
interfaces
.
numpy_fft
.
ifftn
(
local_val
,
axes
=
axes
)
else
:
if
self
.
codomain
.
harmonic
:
result_val
=
np
.
fft
.
fftn
(
local_val
,
axes
=
axes
)
...
...
nifty/operators/fft_operator/transformations/rgrgtransformation.py
View file @
b0d67512
...
...
@@ -18,7 +18,7 @@
import
numpy
as
np
from
transformation
import
Transformation
from
rg_transforms
import
FFT
W
,
NUMPY
FFT
from
rg_transforms
import
MPI
FFT
,
Scalar
FFT
from
nifty
import
RGSpace
,
nifty_configuration
...
...
@@ -30,18 +30,18 @@ class RGRGTransformation(Transformation):
super
(
RGRGTransformation
,
self
).
__init__
(
domain
,
codomain
,
module
)
if
module
is
None
:
if
nifty_configuration
[
'fft_module'
]
==
'
fftw
'
:
self
.
_transform
=
FFT
W
(
self
.
domain
,
self
.
codomain
)
elif
nifty_configuration
[
'fft_module'
]
==
'
numpy
'
:
self
.
_transform
=
NUMPY
FFT
(
self
.
domain
,
self
.
codomain
)
if
nifty_configuration
[
'fft_module'
]
==
'
mpi
'
:
self
.
_transform
=
MPI
FFT
(
self
.
domain
,
self
.
codomain
)
elif
nifty_configuration
[
'fft_module'
]
==
'
scalar
'
:
self
.
_transform
=
Scalar
FFT
(
self
.
domain
,
self
.
codomain
)
else
:
raise
ValueError
(
'Unsupported default FFT module:'
+
nifty_configuration
[
'fft_module'
])
else
:
if
module
==
'
fftw
'
:
self
.
_transform
=
FFT
W
(
self
.
domain
,
self
.
codomain
)
elif
module
==
'
numpy
'
:
self
.
_transform
=
NUMPY
FFT
(
self
.
domain
,
self
.
codomain
)
if
module
==
'
mpi
'
:
self
.
_transform
=
MPI
FFT
(
self
.
domain
,
self
.
codomain
)
elif
module
==
'
scalar
'
:
self
.
_transform
=
Scalar
FFT
(
self
.
domain
,
self
.
codomain
)
else
:
raise
ValueError
(
'Unsupported FFT module:'
+
module
)
...
...
test/test_operators/test_fft_operator.py
View file @
b0d67512
...
...
@@ -62,11 +62,11 @@ class FFTOperatorTests(unittest.TestCase):
res
=
foo
.
get_distance_array
(
'not'
)
assert_equal
(
res
[
zc1
*
(
dim1
//
2
),
zc2
*
(
dim2
//
2
)],
0.
)
@
expand
(
product
([
"
numpy"
,
"fftw
"
],
[
10
,
11
],
[
False
,
True
],
[
False
,
True
],
@
expand
(
product
([
"
scalar"
,
"mpi
"
],
[
10
,
11
],
[
False
,
True
],
[
False
,
True
],
[
0.1
,
1
,
3.7
],
[
np
.
float64
,
np
.
complex128
,
np
.
float32
,
np
.
complex64
]))
def
test_fft1D
(
self
,
module
,
dim1
,
zc1
,
zc2
,
d
,
itp
):
if
module
==
"
fftw
"
and
"
py
fftw"
not
in
di
:
if
module
==
"
mpi
"
and
"fftw
_mpi
"
not
in
di
:
raise
SkipTest
tol
=
_get_rtol
(
itp
)
a
=
RGSpace
(
dim1
,
zerocenter
=
zc1
,
distances
=
d
)
...
...
@@ -78,12 +78,12 @@ class FFTOperatorTests(unittest.TestCase):
out
=
fft
.
adjoint_times
(
fft
.
times
(
inp
))
assert_allclose
(
inp
.
val
,
out
.
val
,
rtol
=
tol
,
atol
=
tol
)
@
expand
(
product
([
"
numpy"
,
"fftw
"
],
[
10
,
11
],
[
9
,
12
],
[
False
,
True
],
@
expand
(
product
([
"
scalar"
,
"mpi
"
],
[
10
,
11
],
[
9
,
12
],
[
False
,
True
],
[
False
,
True
],
[
False
,
True
],
[
False
,
True
],
[
0.1
,
1
,
3.7
],
[
0.4
,
1
,
2.7
],
[
np
.
float64
,
np
.
complex128
,
np
.
float32
,
np
.
complex64
]))
def
test_fft2D
(
self
,
module
,
dim1
,
dim2
,
zc1
,
zc2
,
zc3
,
zc4
,
d1
,
d2
,
itp
):
if
module
==
"
fftw
"
and
"
py
fftw"
not
in
di
:
if
module
==
"
mpi
"
and
"fftw
_mpi
"
not
in
di
:
raise
SkipTest
tol
=
_get_rtol
(
itp
)
a
=
RGSpace
([
dim1
,
dim2
],
zerocenter
=
[
zc1
,
zc2
],
distances
=
[
d1
,
d2
])
...
...
test/test_spaces/test_power_space.py
View file @
b0d67512
...
...
@@ -32,22 +32,22 @@ from itertools import product, chain
from
d2o.config
import
dependency_injector
as
gdi
HARMONIC_SPACES
=
[
RGSpace
((
8
,),
harmonic
=
True
),
RGSpace
((
7
,),
harmonic
=
True
,
zerocenter
=
True
),
RGSpace
((
8
,),
harmonic
=
True
,
zerocenter
=
True
),
RGSpace
((
7
,
8
),
harmonic
=
True
),
RGSpace
((
7
,),
harmonic
=
True
,
zerocenter
=
True
),
RGSpace
((
8
,),
harmonic
=
True
,
zerocenter
=
True
),
RGSpace
((
7
,
8
),
harmonic
=
True
),
RGSpace
((
7
,
8
),
harmonic
=
True
,
zerocenter
=
True
),
RGSpace
((
6
,
6
),
harmonic
=
True
,
zerocenter
=
True
),
RGSpace
((
7
,
5
),
harmonic
=
True
,
zerocenter
=
True
),
RGSpace
((
5
,
5
),
harmonic
=
True
),
RGSpace
((
5
,
5
),
harmonic
=
True
),
RGSpace
((
4
,
5
,
7
),
harmonic
=
True
),
RGSpace
((
4
,
5
,
7
),
harmonic
=
True
,
zerocenter
=
True
),
LMSpace
(
6
),
LMSpace
(
9
)]
#Try all sensible kinds of combinations of spaces, distributuion strategy and
#Try all sensible kinds of combinations of spaces, distributuion strategy and
#binning parameters
_maybe_fftw
=
[
"fftw"
]
if
(
'
py
fftw'
in
gdi
)
else
[]
_maybe_fftw
=
[
"fftw"
]
if
(
'fftw
_mpi
'
in
gdi
)
else
[]
CONSISTENCY_CONFIGS_IMPLICIT
=
product
(
HARMONIC_SPACES
,
[
"not"
,
"equal"
]
+
_maybe_fftw
,
[
None
],
[
None
,
3
,
4
],
[
True
,
False
])
CONSISTENCY_CONFIGS_EXPLICIT
=
product
(
HARMONIC_SPACES
,
[
"not"
,
"equal"
]
+
_maybe_fftw
,
[[
0.
,
1.3
]],[
None
],[
False
])
...
...
@@ -138,13 +138,13 @@ class PowerSpaceConsistencyCheck(unittest.TestCase):
binbounds
=
binbounds
)
assert_equal
(
p
.
pindex
.
flatten
()[
p
.
pundex
],
np
.
arange
(
p
.
dim
),
err_msg
=
'pundex is not right-inverse of pindex!'
)
@
expand
(
CONSISTENCY_CONFIGS
)
def
test_rhopindexConsistency
(
self
,
harmonic_partner
,
distribution_strategy
,
binbounds
,
nbin
,
logarithmic
):
assert_equal
(
p
.
pindex
.
flatten
().
bincount
(),
p
.
rho
,
err_msg
=
'rho is not equal to pindex degeneracy'
)
class
PowerSpaceFunctionalityTest
(
unittest
.
TestCase
):
@
expand
(
CONSISTENCY_CONFIGS
)
def
test_constructor
(
self
,
harmonic_partner
,
distribution_strategy
,
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment