Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
762f85d7
Commit
762f85d7
authored
May 07, 2016
by
Jait Dixit
Browse files
WIP Clean up code
- Reformat code according to PEP8 - Add 'axes' keyword to transforms of both FFT machines.
parent
8c6068e7
Pipeline
#2426
skipped
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
rg/nifty_fft.py
View file @
762f85d7
...
...
@@ -57,6 +57,7 @@ class fft(object):
----------
None
"""
def
__init__
(
self
):
pass
...
...
@@ -168,12 +169,12 @@ class fft_fftw(fft):
# until the desired format is constructed.
core
=
np
.
fromfunction
(
lambda
*
args
:
(
-
1
)
**
(
np
.
tensordot
(
to_center
,
args
+
offset
.
reshape
(
offset
.
shape
+
(
1
,)
*
(
np
.
array
(
args
).
ndim
-
1
)),
1
)),
(
np
.
tensordot
(
to_center
,
args
+
offset
.
reshape
(
offset
.
shape
+
(
1
,)
*
(
np
.
array
(
args
).
ndim
-
1
)),
1
)),
(
2
,)
*
to_center
.
size
)
# Cast the core to the smallest integers we can get
core
=
core
.
astype
(
np
.
int8
)
...
...
@@ -186,7 +187,7 @@ class fft_fftw(fft):
if
(
dimensions
%
2
)[
i
]
==
0
:
continue
# prepare the slice object
temp_slice
=
(
slice
(
None
),)
*
i
+
(
slice
(
-
2
,
-
1
,
1
),)
+
\
temp_slice
=
(
slice
(
None
),)
*
i
+
(
slice
(
-
2
,
-
1
,
1
),)
+
\
(
slice
(
None
),)
*
(
centering_mask
.
ndim
-
1
-
i
)
# append the slice to the centering_mask
centering_mask
=
np
.
append
(
centering_mask
,
...
...
@@ -206,14 +207,14 @@ class fft_fftw(fft):
def
_get_plan_and_info
(
self
,
domain
,
codomain
,
**
kwargs
):
# generate a id-tuple which identifies the domain-codomain setting
temp_id
=
domain
.
__hash__
()
^
(
101
*
codomain
.
__hash__
())
temp_id
=
domain
.
__hash__
()
^
(
101
*
codomain
.
__hash__
())
# generate the plan_and_info object if not already there
if
temp_id
not
in
self
.
plan_dict
:
self
.
plan_dict
[
temp_id
]
=
_fftw_plan_and_info
(
domain
,
codomain
,
self
,
**
kwargs
)
return
self
.
plan_dict
[
temp_id
]
def
transform
(
self
,
val
,
domain
,
codomain
,
**
kwargs
):
def
transform
(
self
,
val
,
domain
,
codomain
,
axes
,
**
kwargs
):
"""
The pyfftw transform function.
...
...
@@ -227,7 +228,10 @@ class fft_fftw(fft):
The domain of the space which should be transformed.
codomain : nifty.rg.nifty_rg.rg_space
The taget into which the field should be transformed.
The target into which the field should be transformed.
axes: tuple, None
The axes which should be transformed.
**kwargs : *optional*
Further kwargs are passed to the create_mpi_plan routine.
...
...
@@ -248,8 +252,8 @@ class fft_fftw(fft):
# Case 1: val is a distributed_data_object
if
isinstance
(
val
,
distributed_data_object
):
return_val
=
val
.
copy_empty
(
global_shape
=
current_plan_and_info
.
global_output_shape
,
dtype
=
codomain
.
dtype
)
global_shape
=
current_plan_and_info
.
global_output_shape
,
dtype
=
codomain
.
dtype
)
# If the distribution strategy of the d2o is fftw, extract
# the data directly
if
val
.
distribution_strategy
==
'fftw'
:
...
...
@@ -272,11 +276,11 @@ class fft_fftw(fft):
p
()
if
p
.
has_output
:
result
=
p
.
output_array
*
current_plan_and_info
.
\
result
=
p
.
output_array
*
current_plan_and_info
.
\
get_domain_centering_mask
()
else
:
result
=
local_val
assert
(
result
.
shape
[
0
]
==
0
)
assert
(
result
.
shape
[
0
]
==
0
)
# build the return object according to the input val
# TODO: Check if comm is the same, too!
...
...
@@ -314,7 +318,6 @@ class fft_fftw(fft):
# The instances of plan_and_info store the fftw plan and all
# other information needed in order to perform a mpi-fftw transformation
class
_fftw_plan_and_info
(
object
):
def
__init__
(
self
,
domain
,
codomain
,
fft_fftw_context
,
**
kwargs
):
if
pyfftw
is
None
:
raise
ImportError
(
"The module pyfftw is needed but not available."
)
...
...
@@ -352,11 +355,11 @@ class _fftw_plan_and_info(object):
self
.
in_zero_centered_dimensions
=
domain
.
paradict
[
'zerocenter'
]
self
.
out_zero_centered_dimensions
=
codomain
.
paradict
[
'zerocenter'
]
self
.
overall_sign
=
(
-
1
)
**
np
.
sum
(
np
.
array
(
self
.
in_zero_centered_dimensions
)
*
np
.
array
(
self
.
out_zero_centered_dimensions
)
*
(
np
.
array
(
self
.
global_input_shape
)
//
2
%
2
)
)
self
.
overall_sign
=
(
-
1
)
**
np
.
sum
(
np
.
array
(
self
.
in_zero_centered_dimensions
)
*
np
.
array
(
self
.
out_zero_centered_dimensions
)
*
(
np
.
array
(
self
.
global_input_shape
)
//
2
%
2
)
)
self
.
local_node_dimensions
=
np
.
append
((
self
.
fftw_local_size
[
1
],),
self
.
global_input_shape
[
1
:])
...
...
@@ -400,13 +403,14 @@ class fft_gfft(fft):
None
"""
def
__init__
(
self
,
fft_module_name
):
self
.
fft_machine
=
gdi
.
get
(
fft_module_name
)
if
self
.
fft_machine
is
None
:
raise
ImportError
(
"The gfft(_dummy)-module is needed but not available."
)
def
transform
(
self
,
val
,
domain
,
codomain
,
**
kwargs
):
def
transform
(
self
,
val
,
domain
,
codomain
,
axes
,
**
kwargs
):
"""
The gfft transform function.
...
...
@@ -420,7 +424,7 @@ class fft_gfft(fft):
The domain of the space which should be transformed.
codomain : nifty.rg.nifty_rg.rg_space
The taget into which the field should be transformed.
The ta
r
get into which the field should be transformed.
**kwargs : *optional*
Further kwargs are not processed.
...
...
@@ -443,36 +447,36 @@ class fft_gfft(fft):
d2oQ
=
False
temp
=
val
# transform and return
if
(
domain
.
dtype
==
np
.
float64
):
if
(
domain
.
dtype
==
np
.
float64
):
temp
=
self
.
fft_machine
.
gfft
(
temp
.
astype
(
np
.
complex128
),
in_ax
=
[],
out_ax
=
[],
ftmachine
=
ftmachine
,
in_zero_center
=
domain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
out_zero_center
=
codomain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
enforce_hermitian_symmetry
=
bool
(
codomain
.
para
[
naxes
]
==
1
),
W
=-
1
,
alpha
=-
1
,
verbose
=
False
)
temp
.
astype
(
np
.
complex128
),
in_ax
=
[],
out_ax
=
[],
ftmachine
=
ftmachine
,
in_zero_center
=
domain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
out_zero_center
=
codomain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
enforce_hermitian_symmetry
=
bool
(
codomain
.
para
[
naxes
]
==
1
),
W
=-
1
,
alpha
=-
1
,
verbose
=
False
)
else
:
temp
=
self
.
fft_machine
.
gfft
(
temp
,
in_ax
=
[],
out_ax
=
[],
ftmachine
=
ftmachine
,
in_zero_center
=
domain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
out_zero_center
=
codomain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
enforce_hermitian_symmetry
=
bool
(
codomain
.
para
[
naxes
]
==
1
),
W
=-
1
,
alpha
=-
1
,
verbose
=
False
)
temp
,
in_ax
=
[],
out_ax
=
[],
ftmachine
=
ftmachine
,
in_zero_center
=
domain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
out_zero_center
=
codomain
.
para
[
-
naxes
:].
astype
(
np
.
bool
).
tolist
(),
enforce_hermitian_symmetry
=
bool
(
codomain
.
para
[
naxes
]
==
1
),
W
=-
1
,
alpha
=-
1
,
verbose
=
False
)
if
d2oQ
:
new_val
=
val
.
copy_empty
(
dtype
=
np
.
complex128
)
new_val
.
set_full_data
(
temp
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment