Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
N
NIFTy
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
13
Issues
13
List
Boards
Labels
Service Desk
Milestones
Merge Requests
13
Merge Requests
13
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Operations
Operations
Incidents
Environments
Packages & Registries
Packages & Registries
Container Registry
Analytics
Analytics
CI / CD
Repository
Value Stream
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ift
NIFTy
Commits
0d492a58
Commit
0d492a58
authored
Jul 11, 2016
by
Jait Dixit
1
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
WIP: Move transforms to transformator
parent
1cad6829
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
709 additions
and
2 deletions
+709
-2
__init__.py
__init__.py
+1
-1
setup.py
setup.py
+1
-1
transforms/__init__.py
transforms/__init__.py
+3
-0
transforms/fftw.py
transforms/fftw.py
+524
-0
transforms/gfft.py
transforms/gfft.py
+108
-0
transforms/transform.py
transforms/transform.py
+26
-0
transforms/transform_factory.py
transforms/transform_factory.py
+46
-0
No files found.
__init__.py
View file @
0d492a58
...
...
@@ -93,4 +93,4 @@ from demos import get_demo_dir
from
pickling
import
_pickle_method
,
_unpickle_method
#import pyximport; pyximport.install(pyimport = True)
from
transforms
import
tf
as
transformator
setup.py
View file @
0d492a58
...
...
@@ -34,7 +34,7 @@ setup(name="ift_nifty",
url
=
"http://www.mpa-garching.mpg.de/ift/nifty/"
,
packages
=
[
"nifty"
,
"nifty.demos"
,
"nifty.rg"
,
"nifty.lm"
,
"nifty.operators"
,
"nifty.dummys"
,
"nifty.field_types"
,
"nifty.config"
,
"nifty.power"
],
"nifty.config"
,
"nifty.power"
,
"nifty.transforms"
],
package_dir
=
{
"nifty"
:
""
},
zip_safe
=
False
,
dependency_links
=
[
...
...
transforms/__init__.py
0 → 100644
View file @
0d492a58
from
transform_factory
import
TransformFactory
tf
=
TransformFactory
()
transforms/fftw.py
0 → 100644
View file @
0d492a58
import
warnings
import
numpy
as
np
from
d2o
import
distributed_data_object
,
STRATEGIES
from
nifty.config
import
about
,
dependency_injector
as
gdi
import
nifty.nifty_utilities
as
utilities
from
transform
import
FFT
pyfftw
=
gdi
.
get
(
'pyfftw'
)
class
FFTW
(
FFT
):
"""
The pyfftw pendant of a fft object.
"""
# The plan_dict stores the FFTWTransformInfo objects which correspond
# to a certain set of (field_val, domain, codomain) sets.
info_dict
=
{}
# initialize the dictionary which stores the values from
# get_centering_mask
centering_mask_dict
=
{}
def
__init__
(
self
,
domain
,
codomain
):
self
.
domain
=
domain
self
.
codomain
=
codomain
if
'pyfftw'
not
in
gdi
:
raise
ImportError
(
"The module pyfftw is needed but not available."
)
self
.
name
=
'pyfftw'
# Enable caching for pyfftw.interfaces
pyfftw
.
interfaces
.
cache
.
enable
()
@
classmethod
def
get_centering_mask
(
cls
,
to_center_input
,
dimensions_input
,
offset_input
=
False
):
"""
Computes the mask, used to (de-)zerocenter domain and target
fields.
Parameters
----------
to_center_input : tuple, list, numpy.ndarray
A tuple of booleans which dimensions should be
zero-centered.
dimensions_input : tuple, list, numpy.ndarray
A tuple containing the mask's desired shape.
offset_input : int, boolean
Specifies whether the zero-th dimension starts with an odd
or and even index, i.e. if it is shifted.
Returns
-------
result : np.ndarray
A 1/-1-alternating mask.
"""
# cast input
to_center
=
np
.
array
(
to_center_input
)
dimensions
=
np
.
array
(
dimensions_input
)
# if none of the dimensions are zero centered, return a 1
if
np
.
all
(
to_center
==
0
):
return
1
if
np
.
all
(
dimensions
==
np
.
array
(
1
))
or
\
np
.
all
(
dimensions
==
np
.
array
([
1
])):
return
dimensions
# The dimensions of size 1 must be sorted out for computing the
# centering_mask. The depth of the array will be restored in the
# end.
size_one_dimensions
=
[]
temp_dimensions
=
[]
temp_to_center
=
[]
for
i
in
range
(
len
(
dimensions
)):
if
dimensions
[
i
]
==
1
:
size_one_dimensions
+=
[
True
]
else
:
size_one_dimensions
+=
[
False
]
temp_dimensions
+=
[
dimensions
[
i
]]
temp_to_center
+=
[
to_center
[
i
]]
dimensions
=
np
.
array
(
temp_dimensions
)
to_center
=
np
.
array
(
temp_to_center
)
# cast the offset_input into the shape of to_center
offset
=
np
.
zeros
(
to_center
.
shape
,
dtype
=
int
)
offset
[
0
]
=
int
(
offset_input
)
# check for dimension match
if
to_center
.
size
!=
dimensions
.
size
:
raise
TypeError
(
'The length of the supplied lists does not match.'
)
# build up the value memory
# compute an identifier for the parameter set
temp_id
=
tuple
(
(
tuple
(
to_center
),
tuple
(
dimensions
),
tuple
(
offset
)))
if
temp_id
not
in
cls
.
centering_mask_dict
:
# use np.tile in order to stack the core alternation scheme
# until the desired format is constructed.
core
=
np
.
fromfunction
(
lambda
*
args
:
(
-
1
)
**
(
np
.
tensordot
(
to_center
,
args
+
offset
.
reshape
(
offset
.
shape
+
(
1
,)
*
(
np
.
array
(
args
).
ndim
-
1
)),
1
)),
(
2
,)
*
to_center
.
size
)
# Cast the core to the smallest integers we can get
core
=
core
.
astype
(
np
.
int8
)
centering_mask
=
np
.
tile
(
core
,
dimensions
//
2
)
# for the dimensions of odd size corresponding slices must be
# added
for
i
in
range
(
centering_mask
.
ndim
):
# check if the size of the certain dimension is odd or even
if
(
dimensions
%
2
)[
i
]
==
0
:
continue
# prepare the slice object
temp_slice
=
(
slice
(
None
),)
*
i
+
(
slice
(
-
2
,
-
1
,
1
),)
+
\
(
slice
(
None
),)
*
(
centering_mask
.
ndim
-
1
-
i
)
# append the slice to the centering_mask
centering_mask
=
np
.
append
(
centering_mask
,
centering_mask
[
temp_slice
],
axis
=
i
)
# Add depth to the centering_mask where the length of a
# dimension was one
temp_slice
=
()
for
i
in
range
(
len
(
size_one_dimensions
)):
if
size_one_dimensions
[
i
]:
temp_slice
+=
(
None
,)
else
:
temp_slice
+=
(
slice
(
None
),)
centering_mask
=
centering_mask
[
temp_slice
]
cls
.
centering_mask_dict
[
temp_id
]
=
centering_mask
return
cls
.
centering_mask_dict
[
temp_id
]
@
classmethod
def
_get_transform_info
(
cls
,
domain
,
codomain
,
local_shape
,
local_offset_Q
,
is_local
,
transform_shape
=
None
,
**
kwargs
):
# generate a id-tuple which identifies the domain-codomain setting
temp_id
=
(
domain
.
__hash__
()
^
(
101
*
codomain
.
__hash__
())
^
(
211
*
transform_shape
.
__hash__
()))
# generate the plan_and_info object if not already there
if
temp_id
not
in
cls
.
info_dict
:
if
is_local
:
cls
.
info_dict
[
temp_id
]
=
FFTWLocalTransformInfo
(
domain
,
codomain
,
local_shape
,
local_offset_Q
,
**
kwargs
)
else
:
cls
.
info_dict
[
temp_id
]
=
FFTWMPITransfromInfo
(
domain
,
codomain
,
local_shape
,
local_offset_Q
,
transform_shape
,
**
kwargs
)
return
cls
.
info_dict
[
temp_id
]
def
_apply_mask
(
self
,
val
,
mask
,
axes
):
"""
Apply centering mask to an array.
Parameters
----------
val: distributed_data_object or numpy.ndarray
The value-array on which the mask should be applied.
mask: numpy.ndarray
The mask to be applied.
axes: tuple
The axes which are to be transformed.
Returns
-------
distributed_data_object or np.nd_array
Mask input array by multiplying it with the mask.
"""
# reshape mask if necessary
if
axes
:
mask
=
mask
.
reshape
(
[
y
if
x
in
axes
else
1
for
x
,
y
in
enumerate
(
val
.
shape
)]
)
return
val
*
mask
def
_atomic_mpi_transform
(
self
,
val
,
info
,
axes
):
# Apply codomain centering mask
if
reduce
(
lambda
x
,
y
:
x
+
y
,
self
.
codomain
.
paradict
[
'zerocenter'
]):
temp_val
=
np
.
copy
(
val
)
val
=
self
.
_apply_mask
(
temp_val
,
info
.
cmask_codomain
,
axes
)
p
=
info
.
plan
# Load the value into the plan
if
p
.
has_input
:
p
.
input_array
[:]
=
val
# Execute the plan
p
()
if
p
.
has_output
:
result
=
p
.
output_array
else
:
return
None
# Apply domain centering mask
if
reduce
(
lambda
x
,
y
:
x
+
y
,
self
.
domain
.
paradict
[
'zerocenter'
]):
result
=
self
.
_apply_mask
(
result
,
info
.
cmask_domain
,
axes
)
# Correct the sign if needed
result
*=
info
.
sign
return
result
def
_local_transform
(
self
,
val
,
axes
,
**
kwargs
):
####
# val must be numpy array or d2o with slicing distributor
###
local_offset_Q
=
False
try
:
local_val
=
val
.
get_local_data
(
copy
=
False
)
if
axes
is
None
or
0
in
axes
:
local_offset_Q
=
val
.
distributor
.
local_shape
[
0
]
%
2
except
(
AttributeError
):
local_val
=
val
current_info
=
self
.
_get_transform_info
(
self
.
domain
,
self
.
codomain
,
local_shape
=
local_val
.
shape
,
local_offset_Q
=
local_offset_Q
,
is_local
=
True
,
**
kwargs
)
# Apply codomain centering mask
if
reduce
(
lambda
x
,
y
:
x
+
y
,
self
.
codomain
.
paradict
[
'zerocenter'
]):
temp_val
=
np
.
copy
(
local_val
)
local_val
=
self
.
_apply_mask
(
temp_val
,
current_info
.
cmask_codomain
,
axes
)
local_result
=
current_info
.
fftw_interface
(
local_val
,
axes
=
axes
,
planner_effort
=
'FFTW_ESTIMATE'
)
# Apply domain centering mask
if
reduce
(
lambda
x
,
y
:
x
+
y
,
self
.
domain
.
paradict
[
'zerocenter'
]):
local_result
=
self
.
_apply_mask
(
local_result
,
current_info
.
cmask_domain
,
axes
)
# Correct the sign if needed
if
current_info
.
sign
!=
1
:
local_result
*=
current_info
.
sign
try
:
# Create return object and insert results inplace
return_val
=
val
.
copy_empty
(
global_shape
=
val
.
shape
,
dtype
=
self
.
codomain
.
dtype
)
return_val
.
set_local_data
(
data
=
local_result
,
copy
=
False
)
except
(
AttributeError
):
return_val
=
local_result
return
return_val
def
_repack_to_fftw_and_transform
(
self
,
val
,
axes
,
**
kwargs
):
temp_val
=
val
.
copy_empty
(
distribution_strategy
=
'fftw'
)
about
.
warnings
.
cprint
(
'WARNING: Repacking d2o to fftw
\
distribution strategy'
)
temp_val
.
set_full_data
(
val
,
copy
=
False
)
# Recursive call to transform
result
=
self
.
transform
(
temp_val
,
axes
,
**
kwargs
)
return_val
=
result
.
copy_empty
(
distribution_strategy
=
val
.
distribution_strategy
)
return_val
.
set_full_data
(
data
=
result
,
copy
=
False
)
return
return_val
def
_mpi_transform
(
self
,
val
,
axes
,
**
kwargs
):
if
axes
is
None
or
0
in
axes
:
local_offset_list
=
np
.
cumsum
(
np
.
concatenate
([[
0
,
],
val
.
distributor
.
all_local_slices
[:,
2
]])
)
local_offset_Q
=
bool
(
local_offset_list
[
val
.
distributor
.
comm
.
rank
]
%
2
)
else
:
local_offset_Q
=
False
return_val
=
val
.
copy_empty
(
global_shape
=
val
.
shape
,
dtype
=
self
.
codomain
.
dtype
)
# Extract local data
local_val
=
val
.
get_local_data
(
copy
=
False
)
# Create temporary storage for slices
temp_val
=
None
# If axes tuple includes all axes, set it to None
if
axes
is
not
None
:
if
set
(
axes
)
==
set
(
range
(
len
(
val
.
shape
))):
axes
=
None
current_info
=
None
for
slice_list
in
utilities
.
get_slice_list
(
local_val
.
shape
,
axes
):
if
slice_list
==
[
slice
(
None
,
None
)]:
inp
=
local_val
else
:
if
temp_val
is
None
:
temp_val
=
np
.
empty_like
(
local_val
)
inp
=
local_val
[
slice_list
]
# This is in order to make FFTW behave properly when slicing input
# over MPI ranks when the input is 1-dimensional. The default
# behaviour is to optimize to take advantage of byte-alignment,
# which doesn't match the slicing strategy for multi-dimensional
# data.
original_shape
=
None
if
len
(
inp
.
shape
)
==
1
:
original_shape
=
inp
.
shape
inp
=
inp
.
reshape
(
inp
.
shape
[
0
],
1
)
if
current_info
is
None
:
current_info
=
self
.
_get_transform_info
(
self
.
domain
,
self
.
codomain
,
local_shape
=
val
.
local_shape
,
local_offset_Q
=
local_offset_Q
,
is_local
=
False
,
transform_shape
=
inp
.
shape
,
**
kwargs
)
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
"ignore"
)
result
=
self
.
_atomic_mpi_transform
(
inp
,
current_info
,
axes
)
if
result
is
None
:
temp_val
=
np
.
empty_like
(
local_val
)
elif
slice_list
==
[
slice
(
None
,
None
)]:
temp_val
=
result
else
:
# Reverting to the original shape i.e. before the input was
# augmented with 1 to make FFTW behave properly.
if
original_shape
is
not
None
:
result
=
result
.
reshape
(
original_shape
)
temp_val
[
slice_list
]
=
result
return_val
.
set_local_data
(
data
=
temp_val
,
copy
=
False
)
return
return_val
def
transform
(
self
,
val
,
axes
=
None
,
**
kwargs
):
"""
The pyfftw transform function.
Parameters
----------
val : distributed_data_object or numpy.ndarray
The value-array of the field which is supposed to
be transformed.
domain : nifty.rg.nifty_rg.rg_space
The domain of the space which should be transformed.
codomain : nifty.rg.nifty_rg.rg_space
The target into which the field should be transformed.
axes: tuple, None
The axes which should be transformed.
**kwargs : *optional*
Further kwargs are passed to the create_mpi_plan routine.
Returns
-------
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# Check if the axes provided are valid given the shape
if
axes
is
not
None
and
\
not
all
(
axis
in
range
(
len
(
val
.
shape
))
for
axis
in
axes
):
raise
ValueError
(
"ERROR: Provided axes does not match array shape"
)
# If the input is a numpy array we transform it locally
if
not
isinstance
(
val
,
distributed_data_object
):
# Cast to a np.ndarray
temp_val
=
np
.
asarray
(
val
)
# local transform doesn't apply transforms inplace
return_val
=
self
.
_local_transform
(
temp_val
,
axes
)
else
:
if
val
.
distribution_strategy
in
STRATEGIES
[
'slicing'
]:
if
axes
is
None
or
0
in
axes
:
if
val
.
distribution_strategy
!=
'fftw'
:
return_val
=
\
self
.
_repack_to_fftw_and_transform
(
val
,
axes
,
**
kwargs
)
else
:
return_val
=
self
.
_mpi_transform
(
val
,
axes
,
**
kwargs
)
else
:
return_val
=
self
.
_local_transform
(
val
,
axes
,
**
kwargs
)
else
:
return_val
=
self
.
_repack_to_fftw_and_transform
(
val
,
axes
,
**
kwargs
)
# If domain is purely real, the result of the FFT is hermitian
if
self
.
domain
.
paradict
[
'complexity'
]
==
0
:
return_val
.
hermitian
=
True
return
return_val
class
FFTWTransformInfo
(
object
):
def
__init__
(
self
,
domain
,
codomain
,
local_shape
,
local_offset_Q
,
**
kwargs
):
if
pyfftw
is
None
:
raise
ImportError
(
"The module pyfftw is needed but not available."
)
self
.
cmask_domain
=
FFTW
.
get_centering_mask
(
domain
.
paradict
[
'zerocenter'
],
local_shape
,
local_offset_Q
)
self
.
cmask_codomain
=
FFTW
.
get_centering_mask
(
codomain
.
paradict
[
'zerocenter'
],
local_shape
,
local_offset_Q
)
# If both domain and codomain are zero-centered the result,
# will get a global minus. Store the sign to correct it.
self
.
sign
=
(
-
1
)
**
np
.
sum
(
np
.
array
(
domain
.
paradict
[
'zerocenter'
])
*
np
.
array
(
codomain
.
paradict
[
'zerocenter'
])
*
(
np
.
array
(
domain
.
shape
)
//
2
%
2
))
@
property
def
cmask_domain
(
self
):
return
self
.
_domain_centering_mask
@
cmask_domain
.
setter
def
cmask_domain
(
self
,
cmask
):
self
.
_domain_centering_mask
=
cmask
@
property
def
cmask_codomain
(
self
):
return
self
.
_codomain_centering_mask
@
cmask_codomain
.
setter
def
cmask_codomain
(
self
,
cmask
):
self
.
_codomain_centering_mask
=
cmask
@
property
def
sign
(
self
):
return
self
.
_sign
@
sign
.
setter
def
sign
(
self
,
sign
):
self
.
_sign
=
sign
class
FFTWLocalTransformInfo
(
FFTWTransformInfo
):
def
__init__
(
self
,
domain
,
codomain
,
local_shape
,
local_offset_Q
,
**
kwargs
):
super
(
FFTWLocalTransformInfo
,
self
).
__init__
(
domain
,
codomain
,
local_shape
,
local_offset_Q
,
**
kwargs
)
if
codomain
.
harmonic
:
self
.
_fftw_interface
=
pyfftw
.
interfaces
.
numpy_fft
.
fftn
else
:
self
.
_fftw_interface
=
pyfftw
.
interfaces
.
numpy_fft
.
ifftn
@
property
def
fftw_interface
(
self
):
return
self
.
_fftw_interface
@
fftw_interface
.
setter
def
fftw_interface
(
self
,
interface
):
about
.
warnings
.
cprint
(
'WARNING: FFTWLocalTransformInfo fftw_interface
\
cannot be modified'
)
class
FFTWMPITransfromInfo
(
FFTWTransformInfo
):
def
__init__
(
self
,
domain
,
codomain
,
local_shape
,
local_offset_Q
,
transform_shape
,
**
kwargs
):
super
(
FFTWMPITransfromInfo
,
self
).
__init__
(
domain
,
codomain
,
local_shape
,
local_offset_Q
,
**
kwargs
)
self
.
_plan
=
pyfftw
.
create_mpi_plan
(
input_shape
=
transform_shape
,
input_dtype
=
'complex128'
,
output_dtype
=
'complex128'
,
direction
=
'FFTW_FORWARD'
if
codomain
.
harmonic
else
'FFTW_BACKWARD'
,
flags
=
[
"FFTW_ESTIMATE"
],
**
kwargs
)
@
property
def
plan
(
self
):
return
self
.
_plan
@
plan
.
setter
def
plan
(
self
,
plan
):
about
.
warnings
.
cprint
(
'WARNING: FFTWMPITransfromInfo plan
\
cannot be modified'
)
transforms/gfft.py
0 → 100644
View file @
0d492a58
import
numpy
as
np
from
transform
import
FFT
from
d2o
import
distributed_data_object
import
nifty.nifty_utilities
as
utilities
class
GFFT
(
FFT
):
"""
The gfft pendant of a fft object.
Parameters
----------
fft_module_name : String
Switch between the gfft module used: 'gfft' and 'gfft_dummy'
"""
def
__init__
(
self
,
domain
,
codomain
,
fft_module
):
self
.
domain
=
domain
self
.
codomain
=
codomain
self
.
fft_machine
=
fft_module
def
transform
(
self
,
val
,
axes
=
None
,
**
kwargs
):
"""
The gfft transform function.
Parameters
----------
val : numpy.ndarray or distributed_data_object
The value-array of the field which is supposed to
be transformed.
axes : None or tuple
The axes which should be transformed.
**kwargs : *optional*
Further kwargs are not processed.
Returns
-------
result : np.ndarray or distributed_data_object
Fourier-transformed pendant of the input field.
"""
# Check if the axes provided are valid given the shape
if
axes
is
not
None
and
\
not
all
(
axis
in
range
(
len
(
val
.
shape
))
for
axis
in
axes
):
raise
ValueError
(
"ERROR: Provided axes does not match array shape"
)
# GFFT doesn't accept d2o objects as input. Consolidate data from
# all nodes into numpy.ndarray before proceeding.
if
isinstance
(
val
,
distributed_data_object
):
temp_inp
=
val
.
get_full_data
()
else
:
temp_inp
=
val
# Cast input datatype to codomain's dtype
temp_inp
=
temp_inp
.
astype
(
np
.
complex128
,
copy
=
False
)
# Array for storing the result
return_val
=
None
for
slice_list
in
utilities
.
get_slice_list
(
temp_inp
.
shape
,
axes
):
# don't copy the whole data array
if
slice_list
==
[
slice
(
None
,
None
)]:
inp
=
temp_inp
else
:
# initialize the return_val object if needed
if
return_val
is
None
:
return_val
=
np
.
empty_like
(
temp_inp
)
inp
=
temp_inp
[
slice_list
]
inp
=
self
.
fft_machine
.
gfft
(
inp
,
in_ax
=
[],
out_ax
=
[],
ftmachine
=
'fft'
if
self
.
codomain
.
harmonic
else
'ifft'
,
in_zero_center
=
map
(
bool
,
self
.
domain
.
paradict
[
'zerocenter'
]
),
out_zero_center
=
map
(
bool
,
self
.
codomain
.
paradict
[
'zerocenter'
]
),
enforce_hermitian_symmetry
=
bool
(
self
.
codomain
.
paradict
[
'complexity'
]
),
W
=-
1
,
alpha
=-
1
,
verbose
=
False
)
if
slice_list
==
[
slice
(
None
,
None
)]:
return_val
=
inp
else
:
return_val
[
slice_list
]
=
inp
if
isinstance
(
val
,
distributed_data_object
):
new_val
=
val
.
copy_empty
(
dtype
=
self
.
codomain
.
dtype
)
new_val
.
set_full_data
(
return_val
,
copy
=
False
)
# If the values living in domain are purely real, the result of
# the fft is hermitian