Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
693d7194
Commit
693d7194
authored
Jun 05, 2017
by
Theo Steininger
Browse files
Some renaming.
parent
b6257738
Pipeline
#13332
failed with stage
in 5 minutes and 7 seconds
Changes
5
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty/config/nifty_config.py
View file @
693d7194
...
...
@@ -28,26 +28,27 @@ __all__ = ['dependency_injector', 'nifty_configuration']
# Setup the dependency injector
dependency_injector
=
keepers
.
DependencyInjector
(
[(
'mpi4py.MPI'
,
'MPI'
),
(
'pyfftw'
,
'fftw'
),
'pyHealpix'
,
'plotly'
])
dependency_injector
.
register
((
'pyfftw'
,
'fftw_mpi'
),
lambda
z
:
hasattr
(
z
,
'FFTW_MPI'
))
dependency_injector
.
register
((
'pyfftw'
,
'fftw_scalar'
))
def
_fft_module_checker
(
z
):
if
z
==
'mpi_fftw'
:
return
'fftw_mpi'
in
dependency_injector
if
z
==
'scalar_fftw'
:
return
'fftw_scalar'
in
dependency_injector
if
z
==
'fftw_mpi'
:
if
'fftw'
in
dependency_injector
:
if
lambda
z
:
hasattr
(
dependency_injector
[
'fftw'
],
'FFTW_MPI'
):
return
True
else
:
return
False
if
z
==
'fftw'
:
return
'fftw'
in
dependency_injector
return
True
# Initialize the variables
variable_fft_module
=
keepers
.
Variable
(
'fft_module'
,
[
'
mpi_fftw'
,
'scalar_
fftw'
,
'
scalar_
numpy'
],
lambda
z
:
_fft_module_checker
(
z
)
)
[
'
fftw_mpi'
,
'
fftw'
,
'numpy'
],
_fft_module_checker
)
def
dtype_validator
(
dtype
):
...
...
nifty/operators/fft_operator/transformations/rg_transforms.py
View file @
693d7194
...
...
@@ -25,8 +25,7 @@ import nifty.nifty_utilities as utilities
from
keepers
import
Loggable
fftw_mpi
=
gdi
.
get
(
'fftw_mpi'
)
fftw_scalar
=
gdi
.
get
(
'fftw_scalar'
)
fftw
=
gdi
.
get
(
'fftw'
)
class
Transform
(
Loggable
,
object
):
...
...
@@ -208,14 +207,14 @@ class MPIFFT(Transform):
def
__init__
(
self
,
domain
,
codomain
):
if
fftw_mpi
is
None
:
if
not
hasattr
(
fftw
,
'FFTW_MPI'
)
:
raise
ImportError
(
"The MPI FFTW module is needed but not available."
)
super
(
MPIFFT
,
self
).
__init__
(
domain
,
codomain
)
# Enable caching
fftw
_mpi
.
interfaces
.
cache
.
enable
()
fftw
.
interfaces
.
cache
.
enable
()
# The plan_dict stores the FFTWTransformInfo objects which correspond
# to a certain set of (field_val, domain, codomain) sets.
...
...
@@ -469,7 +468,7 @@ class MPIFFT(Transform):
class
FFTWTransformInfo
(
object
):
def
__init__
(
self
,
domain
,
codomain
,
axes
,
local_shape
,
local_offset_Q
,
fftw_context
,
**
kwargs
):
if
fftw_mpi
is
None
:
if
not
hasattr
(
fftw
,
'FFTW_MPI'
)
:
raise
ImportError
(
"The MPI FFTW module is needed but not available."
)
...
...
@@ -515,9 +514,9 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
fftw_context
,
**
kwargs
)
if
codomain
.
harmonic
:
self
.
_fftw_interface
=
fftw
_mpi
.
interfaces
.
numpy_fft
.
fftn
self
.
_fftw_interface
=
fftw
.
interfaces
.
numpy_fft
.
fftn
else
:
self
.
_fftw_interface
=
fftw
_mpi
.
interfaces
.
numpy_fft
.
ifftn
self
.
_fftw_interface
=
fftw
.
interfaces
.
numpy_fft
.
ifftn
@
property
def
fftw_interface
(
self
):
...
...
@@ -534,7 +533,7 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
local_offset_Q
,
fftw_context
,
**
kwargs
)
self
.
_plan
=
fftw
_mpi
.
create_mpi_plan
(
self
.
_plan
=
fftw
.
create_mpi_plan
(
input_shape
=
transform_shape
,
input_dtype
=
'complex128'
,
output_dtype
=
'complex128'
,
...
...
@@ -548,22 +547,22 @@ class FFTWMPITransfromInfo(FFTWTransformInfo):
return
self
.
_plan
class
S
calar
FFT
(
Transform
):
class
S
erial
FFT
(
Transform
):
"""
The numpy fft pendant of a fft object.
"""
def
__init__
(
self
,
domain
,
codomain
,
fftw
):
super
(
S
calar
FFT
,
self
).
__init__
(
domain
,
codomain
)
def
__init__
(
self
,
domain
,
codomain
,
use_
fftw
):
super
(
S
erial
FFT
,
self
).
__init__
(
domain
,
codomain
)
if
fftw
and
(
fftw
_scalar
is
None
):
if
use_
fftw
and
(
fftw
is
None
):
raise
ImportError
(
"The s
calar
FFTW module is needed but not available."
)
"The s
erial
FFTW module is needed but not available."
)
self
.
_fftw
=
fftw
self
.
_
use_
fftw
=
use_
fftw
# Enable caching
if
self
.
_fftw
:
fftw
_scalar
.
interfaces
.
cache
.
enable
()
if
self
.
_
use_
fftw
:
fftw
.
interfaces
.
cache
.
enable
()
def
transform
(
self
,
val
,
axes
,
**
kwargs
):
"""
...
...
@@ -640,12 +639,12 @@ class ScalarFFT(Transform):
local_val
=
self
.
_apply_mask
(
temp_val
,
mask
,
axes
)
# perform the transformation
if
self
.
_fftw
:
if
self
.
_
use_
fftw
:
if
self
.
codomain
.
harmonic
:
result_val
=
fftw
_scalar
.
interfaces
.
numpy_fft
.
fftn
(
result_val
=
fftw
.
interfaces
.
numpy_fft
.
fftn
(
local_val
,
axes
=
axes
)
else
:
result_val
=
fftw
_scalar
.
interfaces
.
numpy_fft
.
ifftn
(
result_val
=
fftw
.
interfaces
.
numpy_fft
.
ifftn
(
local_val
,
axes
=
axes
)
else
:
if
self
.
codomain
.
harmonic
:
...
...
nifty/operators/fft_operator/transformations/rgrgtransformation.py
View file @
693d7194
...
...
@@ -18,7 +18,7 @@
import
numpy
as
np
from
transformation
import
Transformation
from
rg_transforms
import
MPIFFT
,
S
calar
FFT
from
rg_transforms
import
MPIFFT
,
S
erial
FFT
from
nifty
import
RGSpace
,
nifty_configuration
...
...
@@ -32,12 +32,14 @@ class RGRGTransformation(Transformation):
if
module
is
None
:
module
=
nifty_configuration
[
'fft_module'
]
if
module
==
'
mpi_
fftw'
:
if
module
==
'fftw
_mpi
'
:
self
.
_transform
=
MPIFFT
(
self
.
domain
,
self
.
codomain
)
elif
module
==
'scalar_fftw'
:
self
.
_transform
=
ScalarFFT
(
self
.
domain
,
self
.
codomain
,
True
)
elif
module
==
'scalar_numpy'
:
self
.
_transform
=
ScalarFFT
(
self
.
domain
,
self
.
codomain
,
False
)
elif
module
==
'fftw'
:
self
.
_transform
=
SerialFFT
(
self
.
domain
,
self
.
codomain
,
use_fftw
=
True
)
elif
module
==
'numpy'
:
self
.
_transform
=
SerialFFT
(
self
.
domain
,
self
.
codomain
,
use_fftw
=
False
)
else
:
raise
ValueError
(
'Unsupported FFT module:'
+
module
)
...
...
test/test_operators/test_fft_operator.py
View file @
693d7194
...
...
@@ -62,14 +62,15 @@ class FFTOperatorTests(unittest.TestCase):
res
=
foo
.
get_distance_array
(
'not'
)
assert_equal
(
res
[
zc1
*
(
dim1
//
2
),
zc2
*
(
dim2
//
2
)],
0.
)
@
expand
(
product
([
"
scalar_
numpy"
,
"
scalar_
fftw"
,
"
mpi_
fftw"
],
@
expand
(
product
([
"numpy"
,
"fftw"
,
"fftw
_mpi
"
],
[
10
,
11
],
[
False
,
True
],
[
False
,
True
],
[
0.1
,
1
,
3.7
],
[
np
.
float64
,
np
.
complex128
,
np
.
float32
,
np
.
complex64
]))
def
test_fft1D
(
self
,
module
,
dim1
,
zc1
,
zc2
,
d
,
itp
):
if
module
==
"mpi_fftw"
and
"fftw_mpi"
not
in
gdi
:
raise
SkipTest
if
module
==
"scalar_fftw"
and
"fftw_scalar"
not
in
gdi
:
if
module
==
"fftw_mpi"
:
if
not
hasattr
(
gdi
.
get
(
'fftw'
),
'FFTW_MPI'
):
raise
SkipTest
if
module
==
"fftw"
and
"fftw"
not
in
gdi
:
raise
SkipTest
tol
=
_get_rtol
(
itp
)
a
=
RGSpace
(
dim1
,
zerocenter
=
zc1
,
distances
=
d
)
...
...
@@ -81,15 +82,16 @@ class FFTOperatorTests(unittest.TestCase):
out
=
fft
.
adjoint_times
(
fft
.
times
(
inp
))
assert_allclose
(
inp
.
val
,
out
.
val
,
rtol
=
tol
,
atol
=
tol
)
@
expand
(
product
([
"
scalar_
numpy"
,
"
scalar_
fftw"
,
"
mpi_
fftw"
],
@
expand
(
product
([
"numpy"
,
"fftw"
,
"fftw
_mpi
"
],
[
10
,
11
],
[
9
,
12
],
[
False
,
True
],
[
False
,
True
],
[
False
,
True
],
[
False
,
True
],
[
0.1
,
1
,
3.7
],
[
0.4
,
1
,
2.7
],
[
np
.
float64
,
np
.
complex128
,
np
.
float32
,
np
.
complex64
]))
def
test_fft2D
(
self
,
module
,
dim1
,
dim2
,
zc1
,
zc2
,
zc3
,
zc4
,
d1
,
d2
,
itp
):
if
module
==
"mpi_fftw"
and
"fftw_mpi"
not
in
gdi
:
raise
SkipTest
if
module
==
"scalar_fftw"
and
"fftw_scalar"
not
in
gdi
:
if
module
==
"fftw_mpi"
:
if
not
hasattr
(
gdi
.
get
(
'fftw'
),
'FFTW_MPI'
):
raise
SkipTest
if
module
==
"fftw"
and
"fftw"
not
in
gdi
:
raise
SkipTest
tol
=
_get_rtol
(
itp
)
a
=
RGSpace
([
dim1
,
dim2
],
zerocenter
=
[
zc1
,
zc2
],
distances
=
[
d1
,
d2
])
...
...
test/test_spaces/test_power_space.py
View file @
693d7194
...
...
@@ -29,7 +29,7 @@ from types import NoneType
from
test.common
import
expand
from
itertools
import
product
,
chain
# needed to check wether fftw is available
from
d2o.config
import
dependency_injector
as
gdi
from
nifty
import
dependency_injector
as
gdi
from
nose.plugins.skip
import
SkipTest
HARMONIC_SPACES
=
[
RGSpace
((
8
,),
harmonic
=
True
),
...
...
@@ -134,24 +134,27 @@ class PowerSpaceInterfaceTest(unittest.TestCase):
class
PowerSpaceConsistencyCheck
(
unittest
.
TestCase
):
@
expand
(
CONSISTENCY_CONFIGS
)
def
test_pipundexInversion
(
self
,
harmonic_partner
,
distribution_strategy
,
binbounds
,
nbin
,
logarithmic
):
if
distribution_strategy
==
"fftw"
and
"fftw_mpi"
not
in
gdi
:
raise
SkipTest
p
=
PowerSpace
(
harmonic_partner
=
harmonic_partner
,
distribution_strategy
=
distribution_strategy
,
logarithmic
=
logarithmic
,
nbin
=
nbin
,
binbounds
=
binbounds
)
assert_equal
(
p
.
pindex
.
flatten
()[
p
.
pundex
],
np
.
arange
(
p
.
dim
),
err_msg
=
'pundex is not right-inverse of pindex!'
)
# @expand(CONSISTENCY_CONFIGS)
# def test_pipundexInversion(self, harmonic_partner, distribution_strategy,
# binbounds, nbin, logarithmic):
# if distribution_strategy == "fftw":
# if not hasattr(gdi.get('fftw'), 'FFTW_MPI'):
# raise SkipTest
# p = PowerSpace(harmonic_partner=harmonic_partner,
# distribution_strategy=distribution_strategy,
# logarithmic=logarithmic, nbin=nbin,
# binbounds=binbounds)
# assert_equal(p.pindex.flatten()[p.pundex], np.arange(p.dim),
# err_msg='pundex is not right-inverse of pindex!')
@
expand
(
CONSISTENCY_CONFIGS
)
def
test_rhopindexConsistency
(
self
,
harmonic_partner
,
distribution_strategy
,
binbounds
,
nbin
,
logarithmic
):
if
distribution_strategy
==
"fftw"
and
"fftw_mpi"
not
in
gdi
:
raise
SkipTest
if
distribution_strategy
==
"fftw"
:
if
not
hasattr
(
gdi
.
get
(
'fftw'
),
'FFTW_MPI'
):
print
(
gdi
.
get
(
'fftw'
),
"blub
\n\n\n
"
)
raise
SkipTest
p
=
PowerSpace
(
harmonic_partner
=
harmonic_partner
,
distribution_strategy
=
distribution_strategy
,
logarithmic
=
logarithmic
,
nbin
=
nbin
,
...
...
@@ -164,7 +167,9 @@ class PowerSpaceFunctionalityTest(unittest.TestCase):
@
expand
(
CONSTRUCTOR_CONFIGS
)
def
test_constructor
(
self
,
harmonic_partner
,
distribution_strategy
,
logarithmic
,
nbin
,
binbounds
,
expected
):
if
distribution_strategy
==
"fftw"
and
"fftw_mpi"
not
in
gdi
:
if
distribution_strategy
==
"fftw"
:
if
not
hasattr
(
gdi
.
get
(
'fftw'
),
'FFTW_MPI'
):
raise
SkipTest
raise
SkipTest
if
'error'
in
expected
:
with
assert_raises
(
expected
[
'error'
]):
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment