Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
On Thursday, 7th July from 1 to 3 pm there will be a maintenance with a short downtime of GitLab.
Open sidebar
ift
NIFTy
Commits
039a629a
Commit
039a629a
authored
May 17, 2018
by
Martin Reinecke
Browse files
streamline the FFT interface
parent
5ed8f324
Pipeline
#29434
passed with stages
in 11 minutes and 52 seconds
Changes
2
Pipelines
2
Hide whitespace changes
Inline
Side-by-side
nifty4/operators/fft_operator.py
View file @
039a629a
...
...
@@ -61,9 +61,7 @@ class FFTOperator(LinearOperator):
adom
.
check_codomain
(
target
)
target
.
check_codomain
(
adom
)
import
pyfftw
pyfftw
.
interfaces
.
cache
.
enable
()
pyfftw
.
interfaces
.
cache
.
set_keepalive_time
(
1000.
)
utilities
.
fft_prep
()
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
...
...
@@ -74,7 +72,6 @@ class FFTOperator(LinearOperator):
return
self
.
_apply_cartesian
(
x
,
mode
)
def
_apply_cartesian
(
self
,
x
,
mode
):
from
pyfftw.interfaces.numpy_fft
import
fftn
axes
=
x
.
domain
.
axes
[
self
.
_space
]
tdom
=
self
.
_target
if
x
.
domain
==
self
.
_domain
else
self
.
_domain
oldax
=
dobj
.
distaxis
(
x
.
val
)
...
...
@@ -110,7 +107,7 @@ class FFTOperator(LinearOperator):
tmp
=
dobj
.
from_local_data
(
shp2d
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
transpose
(
tmp
)
ldat2
=
dobj
.
local_data
(
tmp
)
ldat2
=
fftn
(
ldat2
,
axes
=
(
1
,))
ldat2
=
utilities
.
my_
fftn
(
ldat2
,
axes
=
(
1
,))
ldat2
=
ldat2
.
real
+
ldat2
.
imag
tmp
=
dobj
.
from_local_data
(
tmp
.
shape
,
ldat2
,
distaxis
=
0
)
tmp
=
dobj
.
transpose
(
tmp
)
...
...
nifty4/utilities.py
View file @
039a629a
...
...
@@ -23,7 +23,8 @@ import abc
from
future.utils
import
with_metaclass
__all__
=
[
"get_slice_list"
,
"safe_cast"
,
"parse_spaces"
,
"infer_space"
,
"memo"
,
"NiftyMetaBase"
,
"hartley"
,
"my_fftn_r2c"
]
"memo"
,
"NiftyMetaBase"
,
"fft_prep"
,
"hartley"
,
"my_fftn_r2c"
,
"my_fftn"
]
def
get_slice_list
(
shape
,
axes
):
...
...
@@ -167,6 +168,17 @@ def nthreads():
return
nthreads
.
_val
nthreads
.
_val
=
None
# Optional extra arguments for the FFT calls
# _fft_extra_args = {}
# if exact reproducibility is needed, use this:
_fft_extra_args
=
dict
(
planner_effort
=
'FFTW_ESTIMATE'
)
def
fft_prep
():
import
pyfftw
pyfftw
.
interfaces
.
cache
.
enable
()
pyfftw
.
interfaces
.
cache
.
set_keepalive_time
(
1000.
)
def
hartley
(
a
,
axes
=
None
):
# Check if the axes provided are valid given the shape
...
...
@@ -177,7 +189,7 @@ def hartley(a, axes=None):
raise
TypeError
(
"Hartley transform requires real-valued arrays."
)
from
pyfftw.interfaces.numpy_fft
import
rfftn
tmp
=
rfftn
(
a
,
axes
=
axes
,
threads
=
nthreads
())
tmp
=
rfftn
(
a
,
axes
=
axes
,
threads
=
nthreads
()
,
**
_fft_extra_args
)
def
_fill_array
(
tmp
,
res
,
axes
):
if
axes
is
None
:
...
...
@@ -219,7 +231,7 @@ def my_fftn_r2c(a, axes=None):
raise
TypeError
(
"Transform requires real-valued input arrays."
)
from
pyfftw.interfaces.numpy_fft
import
rfftn
tmp
=
rfftn
(
a
,
axes
=
axes
,
threads
=
nthreads
())
tmp
=
rfftn
(
a
,
axes
=
axes
,
threads
=
nthreads
()
,
**
_fft_extra_args
)
def
_fill_complex_array
(
tmp
,
res
,
axes
):
if
axes
is
None
:
...
...
@@ -250,3 +262,8 @@ def my_fftn_r2c(a, axes=None):
return
res
return
_fill_complex_array
(
tmp
,
np
.
empty_like
(
a
,
dtype
=
tmp
.
dtype
),
axes
)
def
my_fftn
(
a
,
axes
=
None
):
from
pyfftw.interfaces.numpy_fft
import
fftn
return
fftn
(
a
,
axes
=
axes
,
**
_fft_extra_args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment