Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
162f4066
Commit
162f4066
authored
Aug 24, 2016
by
Theo Steininger
Browse files
Merge branch 'transformation_operator' into 'feature/field_multiple_space'
TransformationOperator See merge request
!23
parents
d57b2f02
3265ea51
Changes
19
Hide whitespace changes
Inline
Side-by-side
nifty/__init__.py
View file @
162f4066
...
@@ -52,11 +52,10 @@ from nifty_utilities import *
...
@@ -52,11 +52,10 @@ from nifty_utilities import *
from
field_types
import
FieldType
,
\
from
field_types
import
FieldType
,
\
FieldArray
FieldArray
from
operators
import
*
from
spaces
import
*
from
spaces
import
*
from
operators
import
*
from
demos
import
get_demo_dir
from
demos
import
get_demo_dir
#import pyximport; pyximport.install(pyimport = True)
#import pyximport; pyximport.install(pyimport = True)
from
transformations
import
*
nifty/field.py
View file @
162f4066
...
@@ -40,6 +40,7 @@ class Field(object):
...
@@ -40,6 +40,7 @@ class Field(object):
start
=
start
)
start
=
start
)
self
.
dtype
=
self
.
_infer_dtype
(
dtype
=
dtype
,
self
.
dtype
=
self
.
_infer_dtype
(
dtype
=
dtype
,
val
=
val
,
domain
=
self
.
domain
,
domain
=
self
.
domain
,
field_type
=
self
.
field_type
)
field_type
=
self
.
field_type
)
...
...
nifty/nifty_utilities.py
View file @
162f4066
...
@@ -41,14 +41,14 @@ def get_slice_list(shape, axes):
...
@@ -41,14 +41,14 @@ def get_slice_list(shape, axes):
axes(axis) does not match shape."
)
axes(axis) does not match shape."
)
)
)
axes_select
=
[
0
if
x
in
axes
else
1
for
x
,
y
in
enumerate
(
shape
)]
axes_select
=
[
0
if
x
in
axes
else
1
for
x
,
y
in
enumerate
(
shape
)]
axes_iterables
=
\
axes_iterables
=
\
[
range
(
y
)
for
x
,
y
in
enumerate
(
shape
)
if
x
not
in
axes
]
[
range
(
y
)
for
x
,
y
in
enumerate
(
shape
)
if
x
not
in
axes
]
for
index
in
product
(
*
axes_iterables
):
for
index
in
product
(
*
axes_iterables
):
it_iter
=
iter
(
index
)
it_iter
=
iter
(
index
)
slice_list
=
[
slice_list
=
[
next
(
it_iter
)
next
(
it_iter
)
if
axis
else
slice
(
None
,
None
)
for
axis
in
axes_select
if
axis
else
slice
(
None
,
None
)
for
axis
in
axes_select
]
]
yield
slice_list
yield
slice_list
else
:
else
:
yield
[
slice
(
None
,
None
)]
yield
[
slice
(
None
,
None
)]
...
@@ -68,7 +68,7 @@ def hermitianize_gaussian(x, axes=None):
...
@@ -68,7 +68,7 @@ def hermitianize_gaussian(x, axes=None):
# The fixed points of the point inversion must not be avaraged.
# The fixed points of the point inversion must not be avaraged.
# Hence one must multiply them again with sqrt(0.5)
# Hence one must multiply them again with sqrt(0.5)
# -> Get the middle index of the array
# -> Get the middle index of the array
mid_index
=
np
.
array
(
x
.
shape
,
dtype
=
np
.
int
)
//
2
mid_index
=
np
.
array
(
x
.
shape
,
dtype
=
np
.
int
)
//
2
dimensions
=
mid_index
.
size
dimensions
=
mid_index
.
size
# Use ndindex to iterate over all combinations of zeros and the
# Use ndindex to iterate over all combinations of zeros and the
# mid_index in order to correct all fixed points.
# mid_index in order to correct all fixed points.
...
@@ -78,7 +78,7 @@ def hermitianize_gaussian(x, axes=None):
...
@@ -78,7 +78,7 @@ def hermitianize_gaussian(x, axes=None):
ndlist
=
[
2
if
i
in
axes
else
1
for
i
in
xrange
(
dimensions
)]
ndlist
=
[
2
if
i
in
axes
else
1
for
i
in
xrange
(
dimensions
)]
ndlist
=
tuple
(
ndlist
)
ndlist
=
tuple
(
ndlist
)
for
i
in
np
.
ndindex
(
ndlist
):
for
i
in
np
.
ndindex
(
ndlist
):
temp_index
=
tuple
(
i
*
mid_index
)
temp_index
=
tuple
(
i
*
mid_index
)
x
[
temp_index
]
*=
np
.
sqrt
(
0.5
)
x
[
temp_index
]
*=
np
.
sqrt
(
0.5
)
try
:
try
:
x
.
hermitian
=
True
x
.
hermitian
=
True
...
@@ -109,7 +109,7 @@ def _hermitianize_inverter(x, axes):
...
@@ -109,7 +109,7 @@ def _hermitianize_inverter(x, axes):
# calculate the number of dimensions the input array has
# calculate the number of dimensions the input array has
dimensions
=
len
(
x
.
shape
)
dimensions
=
len
(
x
.
shape
)
# prepare the slicing object which will be used for mirroring
# prepare the slicing object which will be used for mirroring
slice_primitive
=
[
slice
(
None
),
]
*
dimensions
slice_primitive
=
[
slice
(
None
),
]
*
dimensions
# copy the input data
# copy the input data
y
=
x
.
copy
()
y
=
x
.
copy
()
...
@@ -208,8 +208,9 @@ def field_map(ishape, function, *args):
...
@@ -208,8 +208,9 @@ def field_map(ishape, function, *args):
# with ishape (3,4,3) and (3,4,1)
# with ishape (3,4,3) and (3,4,1)
def
get_clipped
(
w
,
ind
):
def
get_clipped
(
w
,
ind
):
w_shape
=
np
.
array
(
np
.
shape
(
w
))
w_shape
=
np
.
array
(
np
.
shape
(
w
))
get_tuple
=
tuple
(
np
.
clip
(
ind
,
0
,
w_shape
-
1
))
get_tuple
=
tuple
(
np
.
clip
(
ind
,
0
,
w_shape
-
1
))
return
w
[
get_tuple
]
return
w
[
get_tuple
]
result
=
np
.
empty_like
(
args
[
0
])
result
=
np
.
empty_like
(
args
[
0
])
for
i
in
xrange
(
reduce
(
lambda
x
,
y
:
x
*
y
,
result
.
shape
)):
for
i
in
xrange
(
reduce
(
lambda
x
,
y
:
x
*
y
,
result
.
shape
)):
ii
=
np
.
unravel_index
(
i
,
result
.
shape
)
ii
=
np
.
unravel_index
(
i
,
result
.
shape
)
...
@@ -229,7 +230,7 @@ def cast_axis_to_tuple(axis, length):
...
@@ -229,7 +230,7 @@ def cast_axis_to_tuple(axis, length):
axis
=
tuple
(
int
(
item
)
for
item
in
axis
)
axis
=
tuple
(
int
(
item
)
for
item
in
axis
)
except
(
TypeError
):
except
(
TypeError
):
if
np
.
isscalar
(
axis
):
if
np
.
isscalar
(
axis
):
axis
=
(
int
(
axis
),
)
axis
=
(
int
(
axis
),)
else
:
else
:
raise
TypeError
(
raise
TypeError
(
"ERROR: Could not convert axis-input to tuple of ints"
)
"ERROR: Could not convert axis-input to tuple of ints"
)
...
@@ -242,7 +243,7 @@ def cast_axis_to_tuple(axis, length):
...
@@ -242,7 +243,7 @@ def cast_axis_to_tuple(axis, length):
# assert that all entries are elements in [0, length]
# assert that all entries are elements in [0, length]
for
elem
in
axis
:
for
elem
in
axis
:
assert
(
0
<=
elem
<
length
)
assert
(
0
<=
elem
<
length
)
return
axis
return
axis
...
@@ -262,3 +263,21 @@ def complex_bincount(x, weights=None, minlength=None):
...
@@ -262,3 +263,21 @@ def complex_bincount(x, weights=None, minlength=None):
return
real_bincount
+
imag_bincount
return
real_bincount
+
imag_bincount
else
:
else
:
return
x
.
bincount
(
weights
=
weights
,
minlength
=
minlength
)
return
x
.
bincount
(
weights
=
weights
,
minlength
=
minlength
)
def
get_default_codomain
(
domain
):
from
nifty.spaces
import
RGSpace
,
HPSpace
,
GLSpace
,
LMSpace
from
nifty.operators.fft_operator.transformations
import
RGRGTransformation
,
\
HPLMTransformation
,
GLLMTransformation
,
LMGLTransformation
if
isinstance
(
domain
,
RGSpace
):
return
RGRGTransformation
.
get_codomain
(
domain
)
elif
isinstance
(
domain
,
HPSpace
):
return
HPLMTransformation
.
get_codomain
(
domain
)
elif
isinstance
(
domain
,
GLSpace
):
return
GLLMTransformation
.
get_codomain
(
domain
)
elif
isinstance
(
domain
,
LMSpace
):
# TODO: get the preferred transformation path from config
return
LMGLTransformation
.
get_codomain
(
domain
)
else
:
raise
TypeError
(
about
.
_errors
.
cstring
(
'ERROR: unknown domain'
))
nifty/operators/__init__.py
View file @
162f4066
...
@@ -25,6 +25,8 @@ from linear_operator import LinearOperator
...
@@ -25,6 +25,8 @@ from linear_operator import LinearOperator
from
endomorphic_operator
import
EndomorphicOperator
from
endomorphic_operator
import
EndomorphicOperator
from
fft_operator
import
*
from
nifty_operators
import
operator
,
\
from
nifty_operators
import
operator
,
\
diagonal_operator
,
\
diagonal_operator
,
\
power_operator
,
\
power_operator
,
\
...
...
nifty/operators/endomorphic_operator/__init__.py
View file @
162f4066
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
from
endmorphic_operator
import
EndomorphicOperator
from
end
o
morphic_operator
import
EndomorphicOperator
nifty/operators/fft_operator/__init__.py
0 → 100644
View file @
162f4066
from
transformations
import
*
from
fft_operator
import
FFTOperator
\ No newline at end of file
nifty/operators/fft_operator/fft_operator.py
0 → 100644
View file @
162f4066
from
nifty.config
import
about
import
nifty.nifty_utilities
as
utilities
from
nifty.operators.linear_operator
import
LinearOperator
from
transformations
import
TransformationFactory
class
FFTOperator
(
LinearOperator
):
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
=
(),
field_type
=
(),
target
=
(),
field_type_target
=
(),
implemented
=
True
):
super
(
FFTOperator
,
self
).
__init__
(
domain
=
domain
,
field_type
=
field_type
,
implemented
=
implemented
)
if
self
.
domain
==
():
raise
TypeError
(
about
.
_errors
.
cstring
(
'ERROR: TransformationOperator needs a single space as '
'input domain.'
))
else
:
if
len
(
self
.
domain
)
>
1
:
raise
TypeError
(
about
.
_errors
.
cstring
(
'ERROR: TransformationOperator accepts only a single '
'space as input domain.'
))
if
self
.
field_type
!=
():
raise
TypeError
(
about
.
_errors
.
cstring
(
'ERROR: TransformationOperator field-type has to be an '
'empty tuple.'
))
# currently not sanitizing the target
self
.
_target
=
self
.
_parse_domain
(
utilities
.
get_default_codomain
(
self
.
domain
[
0
])
)
self
.
_field_type_target
=
self
.
_parse_field_type
(
field_type_target
)
if
self
.
field_type_target
!=
():
raise
TypeError
(
about
.
_errors
.
cstring
(
'ERROR: TransformationOperator target field-type has to be an '
'empty tuple.'
))
self
.
_forward_transformation
=
TransformationFactory
.
create
(
self
.
domain
[
0
],
self
.
target
[
0
]
)
self
.
_inverse_transformation
=
TransformationFactory
.
create
(
self
.
target
[
0
],
self
.
domain
[
0
]
)
def
adjoint_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
return
self
.
inverse_times
(
x
,
spaces
,
types
)
def
adjoint_inverse_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
return
self
.
times
(
x
,
spaces
,
types
)
def
inverse_adjoint_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
return
self
.
times
(
x
,
spaces
,
types
)
def
_times
(
self
,
x
,
spaces
,
types
):
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
))
return
self
.
_forward_transformation
.
transform
(
x
.
val
,
axes
=
spaces
)
def
_inverse_times
(
self
,
x
,
spaces
,
types
):
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
))
return
self
.
_inverse_transformation
.
transform
(
x
.
val
,
axes
=
spaces
)
# ---Mandatory properties and methods---
@
property
def
target
(
self
):
return
self
.
_target
@
property
def
field_type_target
(
self
):
return
self
.
_field_type_target
nifty/transformations/__init__.py
→
nifty/
operators/fft_operator/
transformations/__init__.py
View file @
162f4066
from
rgrgtransformation
import
RGRGTransformation
from
rgrgtransformation
import
RGRGTransformation
from
gllmtransformation
import
GLLMTransformation
from
gllmtransformation
import
GLLMTransformation
from
hplmtransformation
import
HPLMTransformation
from
hplmtransformation
import
HPLMTransformation
from
lmgltransformation
import
LMGLTransformation
from
lmgltransformation
import
LMGLTransformation
from
lmhptransformation
import
LMHPTransformation
from
lmhptransformation
import
LMHPTransformation
from
transformation_factory
import
TransformationFactory
from
transformation_factory
import
TransformationFactory
\ No newline at end of file
nifty/transformations/gllmtransformation.py
→
nifty/
operators/fft_operator/
transformations/gllmtransformation.py
View file @
162f4066
...
@@ -87,7 +87,7 @@ class GLLMTransformation(Transformation):
...
@@ -87,7 +87,7 @@ class GLLMTransformation(Transformation):
"""
"""
if
self
.
domain
.
discrete
:
if
self
.
domain
.
discrete
:
val
=
self
.
domain
.
calc_
weight
(
val
,
power
=-
0.5
)
val
=
self
.
domain
.
weight
(
val
,
power
=-
0.5
,
axes
=
axes
)
# shorthands for transform parameters
# shorthands for transform parameters
nlat
=
self
.
domain
.
paradict
[
'nlat'
]
nlat
=
self
.
domain
.
paradict
[
'nlat'
]
...
...
nifty/transformations/hplmtransformation.py
→
nifty/
operators/fft_operator/
transformations/hplmtransformation.py
View file @
162f4066
...
@@ -85,7 +85,7 @@ class HPLMTransformation(Transformation):
...
@@ -85,7 +85,7 @@ class HPLMTransformation(Transformation):
niter
=
kwargs
[
'niter'
]
if
'niter'
in
kwargs
else
0
niter
=
kwargs
[
'niter'
]
if
'niter'
in
kwargs
else
0
if
self
.
domain
.
discrete
:
if
self
.
domain
.
discrete
:
val
=
self
.
domain
.
calc_
weight
(
val
,
power
=-
0.5
)
val
=
self
.
domain
.
weight
(
val
,
power
=-
0.5
,
axes
=
axes
)
# shorthands for transform parameters
# shorthands for transform parameters
lmax
=
self
.
codomain
.
paradict
[
'lmax'
]
lmax
=
self
.
codomain
.
paradict
[
'lmax'
]
...
...
nifty/transformations/lmgltransformation.py
→
nifty/
operators/fft_operator/
transformations/lmgltransformation.py
View file @
162f4066
...
@@ -131,7 +131,7 @@ class LMGLTransformation(Transformation):
...
@@ -131,7 +131,7 @@ class LMGLTransformation(Transformation):
# re-weight if discrete
# re-weight if discrete
if
self
.
codomain
.
discrete
:
if
self
.
codomain
.
discrete
:
val
=
self
.
codomain
.
calc_
weight
(
val
,
power
=
0.5
)
val
=
self
.
codomain
.
weight
(
val
,
power
=
0.5
,
axes
=
axes
)
if
isinstance
(
val
,
distributed_data_object
):
if
isinstance
(
val
,
distributed_data_object
):
new_val
=
val
.
copy_empty
(
dtype
=
self
.
codomain
.
dtype
)
new_val
=
val
.
copy_empty
(
dtype
=
self
.
codomain
.
dtype
)
...
...
nifty/transformations/lmhptransformation.py
→
nifty/
operators/fft_operator/
transformations/lmhptransformation.py
View file @
162f4066
...
@@ -114,7 +114,7 @@ class LMHPTransformation(Transformation):
...
@@ -114,7 +114,7 @@ class LMHPTransformation(Transformation):
# re-weight if discrete
# re-weight if discrete
if
self
.
codomain
.
discrete
:
if
self
.
codomain
.
discrete
:
val
=
self
.
codomain
.
calc_
weight
(
val
,
power
=
0.5
)
val
=
self
.
codomain
.
weight
(
val
,
power
=
0.5
,
axes
=
axes
)
if
isinstance
(
val
,
distributed_data_object
):
if
isinstance
(
val
,
distributed_data_object
):
new_val
=
val
.
copy_empty
(
dtype
=
self
.
codomain
.
dtype
)
new_val
=
val
.
copy_empty
(
dtype
=
self
.
codomain
.
dtype
)
...
...
nifty/transformations/rg_transforms.py
→
nifty/
operators/fft_operator/
transformations/rg_transforms.py
View file @
162f4066
...
@@ -359,8 +359,7 @@ class FFTW(Transform):
...
@@ -359,8 +359,7 @@ class FFTW(Transform):
local_shape
=
val
.
local_shape
,
local_shape
=
val
.
local_shape
,
local_offset_Q
=
local_offset_Q
,
local_offset_Q
=
local_offset_Q
,
is_local
=
False
,
is_local
=
False
,
transform_shape
=
val
.
shape
,
transform_shape
=
inp
.
shape
,
# TODO: check why inp.shape doesn't work
**
kwargs
**
kwargs
)
)
...
@@ -437,10 +436,6 @@ class FFTW(Transform):
...
@@ -437,10 +436,6 @@ class FFTW(Transform):
val
,
axes
,
**
kwargs
val
,
axes
,
**
kwargs
)
)
# If domain is purely real, the result of the FFT is hermitian
if
self
.
domain
.
paradict
[
'complexity'
]
==
0
:
return_val
.
hermitian
=
True
return
return_val
return
return_val
...
@@ -636,10 +631,6 @@ class GFFT(Transform):
...
@@ -636,10 +631,6 @@ class GFFT(Transform):
if
isinstance
(
val
,
distributed_data_object
):
if
isinstance
(
val
,
distributed_data_object
):
new_val
=
val
.
copy_empty
(
dtype
=
self
.
codomain
.
dtype
)
new_val
=
val
.
copy_empty
(
dtype
=
self
.
codomain
.
dtype
)
new_val
.
set_full_data
(
return_val
,
copy
=
False
)
new_val
.
set_full_data
(
return_val
,
copy
=
False
)
# If the values living in domain are purely real, the result of
# the fft is hermitian
if
self
.
domain
.
paradict
[
'complexity'
]
==
0
:
new_val
.
hermitian
=
True
return_val
=
new_val
return_val
=
new_val
else
:
else
:
return_val
=
return_val
.
astype
(
self
.
codomain
.
dtype
,
copy
=
False
)
return_val
=
return_val
.
astype
(
self
.
codomain
.
dtype
,
copy
=
False
)
...
...
nifty/transformations/rgrgtransformation.py
→
nifty/
operators/fft_operator/
transformations/rgrgtransformation.py
View file @
162f4066
...
@@ -126,13 +126,13 @@ class RGRGTransformation(Transformation):
...
@@ -126,13 +126,13 @@ class RGRGTransformation(Transformation):
"""
"""
if
self
.
_transform
.
codomain
.
harmonic
:
if
self
.
_transform
.
codomain
.
harmonic
:
# correct for forward fft
# correct for forward fft
val
=
self
.
_transform
.
domain
.
calc_
weight
(
val
,
power
=
1
)
val
=
self
.
_transform
.
domain
.
weight
(
val
,
power
=
1
,
axes
=
axes
)
# Perform the transformation
# Perform the transformation
Tval
=
self
.
_transform
.
transform
(
val
,
axes
,
**
kwargs
)
Tval
=
self
.
_transform
.
transform
(
val
,
axes
,
**
kwargs
)
if
not
self
.
_transform
.
codomain
.
harmonic
:
if
not
self
.
_transform
.
codomain
.
harmonic
:
# correct for inverse fft
# correct for inverse fft
Tval
=
self
.
_transform
.
codomain
.
calc_
weight
(
Tval
,
power
=-
1
)
Tval
=
self
.
_transform
.
codomain
.
weight
(
Tval
,
power
=-
1
,
axes
=
axes
)
return
Tval
return
Tval
nifty/transformations/transformation.py
→
nifty/
operators/fft_operator/
transformations/transformation.py
View file @
162f4066
File moved
nifty/transformations/transformation_factory.py
→
nifty/
operators/fft_operator/
transformations/transformation_factory.py
View file @
162f4066
File moved
nifty/operators/linear_operator/__init__.py
View file @
162f4066
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
from
linear_operator
import
LinearOperator
from
linear_operator
import
LinearOperator
from
linear_operator_paradict
import
LinearOperatorParadict
nifty/spaces/lm_space/lm_space.py
View file @
162f4066
...
@@ -16,7 +16,7 @@ from nifty.config import about,\
...
@@ -16,7 +16,7 @@ from nifty.config import about,\
dependency_injector
as
gdi
dependency_injector
as
gdi
from
lm_space_paradict
import
LMSpaceParadict
from
lm_space_paradict
import
LMSpaceParadict
from
nifty.nifty_power_indices
import
lm_power_indices
#
from nifty.nifty_power_indices import lm_power_indices
from
nifty.nifty_random
import
random
from
nifty.nifty_random
import
random
gl
=
gdi
.
get
(
'libsharp_wrapper_gl'
)
gl
=
gdi
.
get
(
'libsharp_wrapper_gl'
)
...
...
test/test_nifty_transforms.py
View file @
162f4066
import
numpy
as
np
import
itertools
from
numpy.testing
import
assert_equal
,
assert_almost_equal
,
assert_raises
import
unittest
import
d2o
import
numpy
as
np
from
nifty.rg.rg_space
import
gc
as
RG_GC
from
nose_parameterized
import
parameterized
from
nose_parameterized
import
parameterized
import
unittest
from
numpy.testing
import
assert_raises
import
itertools
from
nifty
import
RGSpace
,
LMSpace
,
HPSpace
,
GLSpace
from
nifty
import
RGSpace
,
LMSpace
,
HPSpace
,
GLSpace
from
nifty
import
transformator
from
nifty
import
transformator
from
nifty.transformations.rgrgtransformation
import
RGRGTransformation
from
nifty.operators.fft_operator.transformations.rgrgtransformation
import
RGRGTransformation
from
nifty.rg.rg_space
import
gc
as
RG_GC
import
d2o
###############################################################################
###############################################################################
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment