Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
039a629a
Commit
039a629a
authored
May 17, 2018
by
Martin Reinecke
Browse files
streamline the FFT interface
parent
5ed8f324
Pipeline
#29434
passed with stages
in 11 minutes and 52 seconds
Changes
2
Pipelines
2
Hide whitespace changes
Inline
Side-by-side
nifty4/operators/fft_operator.py
View file @
039a629a
...
...
@@ -61,9 +61,7 @@ class FFTOperator(LinearOperator):
adom
.
check_codomain
(
target
)
target
.
check_codomain
(
adom
)
import
pyfftw
pyfftw
.
interfaces
.
cache
.
enable
()
pyfftw
.
interfaces
.
cache
.
set_keepalive_time
(
1000.
)
utilities
.
fft_prep
()
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
...
...
@@ -74,7 +72,6 @@ class FFTOperator(LinearOperator):
return
self
.
_apply_cartesian
(
x
,
mode
)
def
_apply_cartesian
(
self
,
x
,
mode
):
from
pyfftw.interfaces.numpy_fft
import
fftn
axes
=
x
.
domain
.
axes
[
self
.
_space
]
tdom
=
self
.
_target
if
x
.
domain
==
self
.
_domain
else
self
.
_domain
oldax
=
dobj
.
distaxis
(
x
.
val
)
...
...
@@ -110,7 +107,7 @@ class FFTOperator(LinearOperator):
tmp
=
dobj
.
from_local_data
(
shp2d
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
transpose
(
tmp
)
ldat2
=
dobj
.
local_data
(
tmp
)
ldat2
=
fftn
(
ldat2
,
axes
=
(
1
,))
ldat2
=
utilities
.
my_
fftn
(
ldat2
,
axes
=
(
1
,))
ldat2
=
ldat2
.
real
+
ldat2
.
imag
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
transpose
(
tmp
)
...
...
nifty4/utilities.py
View file @
039a629a
...
...
@@ -23,7 +23,8 @@ import abc
from
future.utils
import
with_metaclass
__all__
=
[
"get_slice_list"
,
"safe_cast"
,
"parse_spaces"
,
"infer_space"
,
"memo"
,
"NiftyMetaBase"
,
"hartley"
,
"my_fftn_r2c"
]
"memo"
,
"NiftyMetaBase"
,
"fft_prep"
,
"hartley"
,
"my_fftn_r2c"
,
"my_fftn"
]
def
get_slice_list
(
shape
,
axes
):
...
...
@@ -167,6 +168,17 @@ def nthreads():
return
nthreads
.
_val
nthreads
.
_val
=
None
# Optional extra arguments for the FFT calls
# _fft_extra_args = {}
# if exact reproducibility is needed, use this:
_fft_extra_args
=
dict
(
planner_effort
=
'FFTW_ESTIMATE'
)
def
fft_prep
():
import
pyfftw
pyfftw
.
interfaces
.
cache
.
enable
()
pyfftw
.
interfaces
.
cache
.
set_keepalive_time
(
1000.
)
def
hartley
(
a
,
axes
=
None
):
# Check if the axes provided are valid given the shape
...
...
@@ -177,7 +189,7 @@ def hartley(a, axes=None):
raise
TypeError
(
"Hartley transform requires real-valued arrays."
)
from
pyfftw.interfaces.numpy_fft
import
rfftn
tmp
=
rfftn
(
a
,
axes
=
axes
,
threads
=
nthreads
())
tmp
=
rfftn
(
a
,
axes
=
axes
,
threads
=
nthreads
()
,
**
_fft_extra_args
)
def
_fill_array
(
tmp
,
res
,
axes
):
if
axes
is
None
:
...
...
@@ -219,7 +231,7 @@ def my_fftn_r2c(a, axes=None):
raise
TypeError
(
"Transform requires real-valued input arrays."
)
from
pyfftw.interfaces.numpy_fft
import
rfftn
tmp
=
rfftn
(
a
,
axes
=
axes
,
threads
=
nthreads
())
tmp
=
rfftn
(
a
,
axes
=
axes
,
threads
=
nthreads
()
,
**
_fft_extra_args
)
def
_fill_complex_array
(
tmp
,
res
,
axes
):
if
axes
is
None
:
...
...
@@ -250,3 +262,8 @@ def my_fftn_r2c(a, axes=None):
return
res
return
_fill_complex_array
(
tmp
,
np
.
empty_like
(
a
,
dtype
=
tmp
.
dtype
),
axes
)
def
my_fftn
(
a
,
axes
=
None
):
from
pyfftw.interfaces.numpy_fft
import
fftn
return
fftn
(
a
,
axes
=
axes
,
**
_fft_extra_args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment