Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
039a629a
Commit
039a629a
authored
May 17, 2018
by
Martin Reinecke
Browse files
streamline the FFT interface
parent
5ed8f324
Pipeline
#29434
passed with stages
in 11 minutes and 52 seconds
Changes
2
Pipelines
2
Hide whitespace changes
Inline
Side-by-side
nifty4/operators/fft_operator.py
View file @
039a629a
...
@@ -61,9 +61,7 @@ class FFTOperator(LinearOperator):
...
@@ -61,9 +61,7 @@ class FFTOperator(LinearOperator):
adom
.
check_codomain
(
target
)
adom
.
check_codomain
(
target
)
target
.
check_codomain
(
adom
)
target
.
check_codomain
(
adom
)
import
pyfftw
utilities
.
fft_prep
()
pyfftw
.
interfaces
.
cache
.
enable
()
pyfftw
.
interfaces
.
cache
.
set_keepalive_time
(
1000.
)
def
apply
(
self
,
x
,
mode
):
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
self
.
_check_input
(
x
,
mode
)
...
@@ -74,7 +72,6 @@ class FFTOperator(LinearOperator):
...
@@ -74,7 +72,6 @@ class FFTOperator(LinearOperator):
return
self
.
_apply_cartesian
(
x
,
mode
)
return
self
.
_apply_cartesian
(
x
,
mode
)
def
_apply_cartesian
(
self
,
x
,
mode
):
def
_apply_cartesian
(
self
,
x
,
mode
):
from
pyfftw.interfaces.numpy_fft
import
fftn
axes
=
x
.
domain
.
axes
[
self
.
_space
]
axes
=
x
.
domain
.
axes
[
self
.
_space
]
tdom
=
self
.
_target
if
x
.
domain
==
self
.
_domain
else
self
.
_domain
tdom
=
self
.
_target
if
x
.
domain
==
self
.
_domain
else
self
.
_domain
oldax
=
dobj
.
distaxis
(
x
.
val
)
oldax
=
dobj
.
distaxis
(
x
.
val
)
...
@@ -110,7 +107,7 @@ class FFTOperator(LinearOperator):
...
@@ -110,7 +107,7 @@ class FFTOperator(LinearOperator):
tmp
=
dobj
.
from_local_data
(
shp2d
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
from_local_data
(
shp2d
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
transpose
(
tmp
)
tmp
=
dobj
.
transpose
(
tmp
)
ldat2
=
dobj
.
local_data
(
tmp
)
ldat2
=
dobj
.
local_data
(
tmp
)
ldat2
=
fftn
(
ldat2
,
axes
=
(
1
,))
ldat2
=
utilities
.
my_
fftn
(
ldat2
,
axes
=
(
1
,))
ldat2
=
ldat2
.
real
+
ldat2
.
imag
ldat2
=
ldat2
.
real
+
ldat2
.
imag
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
transpose
(
tmp
)
tmp
=
dobj
.
transpose
(
tmp
)
...
...
nifty4/utilities.py
View file @
039a629a
...
@@ -23,7 +23,8 @@ import abc
...
@@ -23,7 +23,8 @@ import abc
from
future.utils
import
with_metaclass
from
future.utils
import
with_metaclass
__all__
=
[
"get_slice_list"
,
"safe_cast"
,
"parse_spaces"
,
"infer_space"
,
__all__
=
[
"get_slice_list"
,
"safe_cast"
,
"parse_spaces"
,
"infer_space"
,
"memo"
,
"NiftyMetaBase"
,
"hartley"
,
"my_fftn_r2c"
]
"memo"
,
"NiftyMetaBase"
,
"fft_prep"
,
"hartley"
,
"my_fftn_r2c"
,
"my_fftn"
]
def
get_slice_list
(
shape
,
axes
):
def
get_slice_list
(
shape
,
axes
):
...
@@ -167,6 +168,17 @@ def nthreads():
...
@@ -167,6 +168,17 @@ def nthreads():
return
nthreads
.
_val
return
nthreads
.
_val
nthreads
.
_val
=
None
nthreads
.
_val
=
None
# Optional extra arguments for the FFT calls
# _fft_extra_args = {}
# if exact reproducibility is needed, use this:
_fft_extra_args
=
dict
(
planner_effort
=
'FFTW_ESTIMATE'
)
def
fft_prep
():
import
pyfftw
pyfftw
.
interfaces
.
cache
.
enable
()
pyfftw
.
interfaces
.
cache
.
set_keepalive_time
(
1000.
)
def
hartley
(
a
,
axes
=
None
):
def
hartley
(
a
,
axes
=
None
):
# Check if the axes provided are valid given the shape
# Check if the axes provided are valid given the shape
...
@@ -177,7 +189,7 @@ def hartley(a, axes=None):
...
@@ -177,7 +189,7 @@ def hartley(a, axes=None):
raise
TypeError
(
"Hartley transform requires real-valued arrays."
)
raise
TypeError
(
"Hartley transform requires real-valued arrays."
)
from
pyfftw.interfaces.numpy_fft
import
rfftn
from
pyfftw.interfaces.numpy_fft
import
rfftn
tmp
=
rfftn
(
a
,
axes
=
axes
,
threads
=
nthreads
())
tmp
=
rfftn
(
a
,
axes
=
axes
,
threads
=
nthreads
()
,
**
_fft_extra_args
)
def
_fill_array
(
tmp
,
res
,
axes
):
def
_fill_array
(
tmp
,
res
,
axes
):
if
axes
is
None
:
if
axes
is
None
:
...
@@ -219,7 +231,7 @@ def my_fftn_r2c(a, axes=None):
...
@@ -219,7 +231,7 @@ def my_fftn_r2c(a, axes=None):
raise
TypeError
(
"Transform requires real-valued input arrays."
)
raise
TypeError
(
"Transform requires real-valued input arrays."
)
from
pyfftw.interfaces.numpy_fft
import
rfftn
from
pyfftw.interfaces.numpy_fft
import
rfftn
tmp
=
rfftn
(
a
,
axes
=
axes
,
threads
=
nthreads
())
tmp
=
rfftn
(
a
,
axes
=
axes
,
threads
=
nthreads
()
,
**
_fft_extra_args
)
def
_fill_complex_array
(
tmp
,
res
,
axes
):
def
_fill_complex_array
(
tmp
,
res
,
axes
):
if
axes
is
None
:
if
axes
is
None
:
...
@@ -250,3 +262,8 @@ def my_fftn_r2c(a, axes=None):
...
@@ -250,3 +262,8 @@ def my_fftn_r2c(a, axes=None):
return
res
return
res
return
_fill_complex_array
(
tmp
,
np
.
empty_like
(
a
,
dtype
=
tmp
.
dtype
),
axes
)
return
_fill_complex_array
(
tmp
,
np
.
empty_like
(
a
,
dtype
=
tmp
.
dtype
),
axes
)
def
my_fftn
(
a
,
axes
=
None
):
from
pyfftw.interfaces.numpy_fft
import
fftn
return
fftn
(
a
,
axes
=
axes
,
**
_fft_extra_args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment