Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
90d6e2f7
Commit
90d6e2f7
authored
Oct 25, 2016
by
Theo Steininger
Browse files
Merge branch 'fix_fftw_transform' into 'master'
Fix fftw transform See merge request
!35
parents
72fa347b
2b070f69
Changes
4
Hide whitespace changes
Inline
Side-by-side
nifty/minimization/line_searching/line_search.py
View file @
90d6e2f7
...
@@ -7,7 +7,7 @@ from nifty import LineEnergy
...
@@ -7,7 +7,7 @@ from nifty import LineEnergy
class
LineSearch
(
object
,
Loggable
):
class
LineSearch
(
object
,
Loggable
):
"""
"""
Class for finding a step size.
◙
Class for finding a step size.
"""
"""
__metaclass__
=
abc
.
ABCMeta
__metaclass__
=
abc
.
ABCMeta
...
...
nifty/operators/fft_operator/transformations/rg_transforms.py
View file @
90d6e2f7
...
@@ -164,24 +164,26 @@ class FFTW(Transform):
...
@@ -164,24 +164,26 @@ class FFTW(Transform):
self
.
centering_mask_dict
[
temp_id
]
=
centering_mask
self
.
centering_mask_dict
[
temp_id
]
=
centering_mask
return
self
.
centering_mask_dict
[
temp_id
]
return
self
.
centering_mask_dict
[
temp_id
]
def
_get_transform_info
(
self
,
domain
,
codomain
,
local_shape
,
def
_get_transform_info
(
self
,
domain
,
codomain
,
axes
,
local_shape
,
local_offset_Q
,
is_local
,
transform_shape
=
None
,
local_offset_Q
,
is_local
,
transform_shape
=
None
,
**
kwargs
):
**
kwargs
):
# generate a id-tuple which identifies the domain-codomain setting
# generate a id-tuple which identifies the domain-codomain setting
temp_id
=
(
domain
.
__hash__
()
^
temp_id
=
(
domain
.
__hash__
()
^
(
101
*
codomain
.
__hash__
())
^
(
101
*
codomain
.
__hash__
())
^
(
211
*
transform_shape
.
__hash__
()))
(
211
*
transform_shape
.
__hash__
())
^
(
131
*
is_local
.
__hash__
())
)
# generate the plan_and_info object if not already there
# generate the plan_and_info object if not already there
if
temp_id
not
in
self
.
info_dict
:
if
temp_id
not
in
self
.
info_dict
:
if
is_local
:
if
is_local
:
self
.
info_dict
[
temp_id
]
=
FFTWLocalTransformInfo
(
self
.
info_dict
[
temp_id
]
=
FFTWLocalTransformInfo
(
domain
,
codomain
,
local_shape
,
domain
,
codomain
,
axes
,
local_shape
,
local_offset_Q
,
self
,
**
kwargs
local_offset_Q
,
self
,
**
kwargs
)
)
else
:
else
:
self
.
info_dict
[
temp_id
]
=
FFTWMPITransfromInfo
(
self
.
info_dict
[
temp_id
]
=
FFTWMPITransfromInfo
(
domain
,
codomain
,
local_shape
,
domain
,
codomain
,
axes
,
local_shape
,
local_offset_Q
,
self
,
transform_shape
,
**
kwargs
local_offset_Q
,
self
,
transform_shape
,
**
kwargs
)
)
...
@@ -248,17 +250,16 @@ class FFTW(Transform):
...
@@ -248,17 +250,16 @@ class FFTW(Transform):
# val must be numpy array or d2o with slicing distributor
# val must be numpy array or d2o with slicing distributor
###
###
local_offset_Q
=
False
try
:
try
:
local_val
=
val
.
get_local_data
(
copy
=
False
)
local_val
=
val
.
get_local_data
(
copy
=
False
)
if
axes
is
None
or
0
in
axes
:
local_offset_Q
=
val
.
distributor
.
local_shape
[
0
]
%
2
except
(
AttributeError
):
except
(
AttributeError
):
local_val
=
val
local_val
=
val
current_info
=
self
.
_get_transform_info
(
self
.
domain
,
current_info
=
self
.
_get_transform_info
(
self
.
domain
,
self
.
codomain
,
self
.
codomain
,
axes
,
local_shape
=
local_val
.
shape
,
local_shape
=
local_val
.
shape
,
local_offset_Q
=
local_offset_Q
,
local_offset_Q
=
False
,
is_local
=
True
,
is_local
=
True
,
**
kwargs
)
**
kwargs
)
...
@@ -309,14 +310,10 @@ class FFTW(Transform):
...
@@ -309,14 +310,10 @@ class FFTW(Transform):
def
_mpi_transform
(
self
,
val
,
axes
,
**
kwargs
):
def
_mpi_transform
(
self
,
val
,
axes
,
**
kwargs
):
if
axes
is
None
or
0
in
axes
:
local_offset_list
=
np
.
cumsum
(
local_offset_list
=
np
.
cumsum
(
np
.
concatenate
([[
0
,
],
val
.
distributor
.
all_local_slices
[:,
2
]])
np
.
concatenate
([[
0
,
],
val
.
distributor
.
all_local_slices
[:,
2
]])
)
)
local_offset_Q
=
bool
(
local_offset_list
[
val
.
distributor
.
comm
.
rank
]
%
2
)
local_offset_Q
=
bool
(
local_offset_list
[
val
.
distributor
.
comm
.
rank
]
%
2
)
else
:
local_offset_Q
=
False
return_val
=
val
.
copy_empty
(
global_shape
=
val
.
shape
,
return_val
=
val
.
copy_empty
(
global_shape
=
val
.
shape
,
dtype
=
self
.
codomain
.
dtype
)
dtype
=
self
.
codomain
.
dtype
)
...
@@ -362,6 +359,7 @@ class FFTW(Transform):
...
@@ -362,6 +359,7 @@ class FFTW(Transform):
current_info
=
self
.
_get_transform_info
(
current_info
=
self
.
_get_transform_info
(
self
.
domain
,
self
.
domain
,
self
.
codomain
,
self
.
codomain
,
axes
,
local_shape
=
val
.
local_shape
,
local_shape
=
val
.
local_shape
,
local_offset_Q
=
local_offset_Q
,
local_offset_Q
=
local_offset_Q
,
is_local
=
False
,
is_local
=
False
,
...
@@ -446,20 +444,22 @@ class FFTW(Transform):
...
@@ -446,20 +444,22 @@ class FFTW(Transform):
class
FFTWTransformInfo
(
object
):
class
FFTWTransformInfo
(
object
):
def
__init__
(
self
,
domain
,
codomain
,
local_shape
,
def
__init__
(
self
,
domain
,
codomain
,
axes
,
local_shape
,
local_offset_Q
,
fftw_context
,
**
kwargs
):
local_offset_Q
,
fftw_context
,
**
kwargs
):
if
pyfftw
is
None
:
if
pyfftw
is
None
:
raise
ImportError
(
"The module pyfftw is needed but not available."
)
raise
ImportError
(
"The module pyfftw is needed but not available."
)
self
.
cmask_domain
=
fftw_context
.
get_centering_mask
(
shape
=
(
local_shape
if
axes
is
None
else
domain
.
zerocenter
,
[
y
for
x
,
y
in
enumerate
(
local_shape
)
if
x
in
axes
])
local_shape
,
local_offset_Q
)
self
.
cmask_domain
=
fftw_context
.
get_centering_mask
(
domain
.
zerocenter
,
shape
,
local_offset_Q
)
self
.
cmask_codomain
=
fftw_context
.
get_centering_mask
(
self
.
cmask_codomain
=
fftw_context
.
get_centering_mask
(
codomain
.
zerocenter
,
codomain
.
zerocenter
,
local_
shape
,
shape
,
local_offset_Q
)
local_offset_Q
)
# If both domain and codomain are zero-centered the result,
# If both domain and codomain are zero-centered the result,
# will get a global minus. Store the sign to correct it.
# will get a global minus. Store the sign to correct it.
...
@@ -493,10 +493,11 @@ class FFTWTransformInfo(object):
...
@@ -493,10 +493,11 @@ class FFTWTransformInfo(object):
class
FFTWLocalTransformInfo
(
FFTWTransformInfo
):
class
FFTWLocalTransformInfo
(
FFTWTransformInfo
):
def
__init__
(
self
,
domain
,
codomain
,
local_shape
,
def
__init__
(
self
,
domain
,
codomain
,
axes
,
local_shape
,
local_offset_Q
,
fftw_context
,
**
kwargs
):
local_offset_Q
,
fftw_context
,
**
kwargs
):
super
(
FFTWLocalTransformInfo
,
self
).
__init__
(
domain
,
super
(
FFTWLocalTransformInfo
,
self
).
__init__
(
domain
,
codomain
,
codomain
,
axes
,
local_shape
,
local_shape
,
local_offset_Q
,
local_offset_Q
,
fftw_context
,
fftw_context
,
...
@@ -512,10 +513,11 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
...
@@ -512,10 +513,11 @@ class FFTWLocalTransformInfo(FFTWTransformInfo):
class
FFTWMPITransfromInfo
(
FFTWTransformInfo
):
class
FFTWMPITransfromInfo
(
FFTWTransformInfo
):
def
__init__
(
self
,
domain
,
codomain
,
local_shape
,
def
__init__
(
self
,
domain
,
codomain
,
axes
,
local_shape
,
local_offset_Q
,
fftw_context
,
transform_shape
,
**
kwargs
):
local_offset_Q
,
fftw_context
,
transform_shape
,
**
kwargs
):
super
(
FFTWMPITransfromInfo
,
self
).
__init__
(
domain
,
super
(
FFTWMPITransfromInfo
,
self
).
__init__
(
domain
,
codomain
,
codomain
,
axes
,
local_shape
,
local_shape
,
local_offset_Q
,
local_offset_Q
,
fftw_context
,
fftw_context
,
...
...
nifty/operators/fft_operator/transformations/rgrgtransformation.py
View file @
90d6e2f7
...
@@ -7,29 +7,29 @@ from nifty import RGSpace, nifty_configuration
...
@@ -7,29 +7,29 @@ from nifty import RGSpace, nifty_configuration
class
RGRGTransformation
(
Transformation
):
class
RGRGTransformation
(
Transformation
):
def
__init__
(
self
,
domain
,
codomain
=
None
,
module
=
None
):
def
__init__
(
self
,
domain
,
codomain
=
None
,
module
=
None
):
super
(
RGRGTransformation
,
self
).
__init__
(
domain
,
codomain
,
module
)
super
(
RGRGTransformation
,
self
).
__init__
(
domain
,
codomain
)
if
module
is
None
:
if
module
is
None
:
if
nifty_configuration
[
'fft_module'
]
==
'pyfftw'
:
if
nifty_configuration
[
'fft_module'
]
==
'pyfftw'
:
self
.
_transform
=
FFTW
(
domain
,
codomain
)
self
.
_transform
=
FFTW
(
self
.
domain
,
self
.
codomain
)
elif
(
nifty_configuration
[
'fft_module'
]
==
'gfft'
or
elif
(
nifty_configuration
[
'fft_module'
]
==
'gfft'
or
nifty_configuration
[
'fft_module'
]
==
'gfft_dummy'
):
nifty_configuration
[
'fft_module'
]
==
'gfft_dummy'
):
self
.
_transform
=
\
self
.
_transform
=
\
GFFT
(
domain
,
GFFT
(
self
.
domain
,
codomain
,
self
.
codomain
,
gdi
.
get
(
nifty_configuration
[
'fft_module'
]))
gdi
.
get
(
nifty_configuration
[
'fft_module'
]))
else
:
else
:
raise
ValueError
(
'ERROR: unknow default FFT module:'
+
raise
ValueError
(
'ERROR: unknow default FFT module:'
+
nifty_configuration
[
'fft_module'
])
nifty_configuration
[
'fft_module'
])
else
:
else
:
if
module
==
'pyfftw'
:
if
module
==
'pyfftw'
:
self
.
_transform
=
FFTW
(
domain
,
codomain
)
self
.
_transform
=
FFTW
(
self
.
domain
,
self
.
codomain
)
elif
module
==
'gfft'
:
elif
module
==
'gfft'
:
self
.
_transform
=
\
self
.
_transform
=
\
GFFT
(
domain
,
codomain
,
gdi
.
get
(
'gfft'
))
GFFT
(
self
.
domain
,
self
.
codomain
,
gdi
.
get
(
'gfft'
))
elif
module
==
'gfft_dummy'
:
elif
module
==
'gfft_dummy'
:
self
.
_transform
=
\
self
.
_transform
=
\
GFFT
(
domain
,
codomain
,
gdi
.
get
(
'gfft_dummy'
))
GFFT
(
self
.
domain
,
self
.
codomain
,
gdi
.
get
(
'gfft_dummy'
))
else
:
else
:
raise
ValueError
(
'ERROR: unknow FFT module:'
+
module
)
raise
ValueError
(
'ERROR: unknow FFT module:'
+
module
)
...
...
nifty/operators/fft_operator/transformations/transformation.py
View file @
90d6e2f7
...
@@ -11,7 +11,7 @@ class Transformation(object, Loggable):
...
@@ -11,7 +11,7 @@ class Transformation(object, Loggable):
"""
"""
__metaclass__
=
abc
.
ABCMeta
__metaclass__
=
abc
.
ABCMeta
def
__init__
(
self
,
domain
,
codomain
=
None
,
module
=
None
):
def
__init__
(
self
,
domain
,
codomain
):
if
codomain
is
None
:
if
codomain
is
None
:
self
.
domain
=
domain
self
.
domain
=
domain
self
.
codomain
=
self
.
get_codomain
(
domain
)
self
.
codomain
=
self
.
get_codomain
(
domain
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment