Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
162f4066
Commit
162f4066
authored
Aug 24, 2016
by
Theo Steininger
Browse files
Merge branch 'transformation_operator' into 'feature/field_multiple_space'
TransformationOperator See merge request
!23
parents
d57b2f02
3265ea51
Changes
19
Hide whitespace changes
Inline
Side-by-side
nifty/__init__.py
View file @
162f4066
...
...
@@ -52,11 +52,10 @@ from nifty_utilities import *
from
field_types
import
FieldType
,
\
FieldArray
from
operators
import
*
from
spaces
import
*
from
operators
import
*
from
demos
import
get_demo_dir
#import pyximport; pyximport.install(pyimport = True)
from
transformations
import
*
nifty/field.py
View file @
162f4066
...
...
@@ -40,6 +40,7 @@ class Field(object):
start
=
start
)
self
.
dtype
=
self
.
_infer_dtype
(
dtype
=
dtype
,
val
=
val
,
domain
=
self
.
domain
,
field_type
=
self
.
field_type
)
...
...
nifty/nifty_utilities.py
View file @
162f4066
...
...
@@ -41,14 +41,14 @@ def get_slice_list(shape, axes):
axes(axis) does not match shape."
)
)
axes_select
=
[
0
if
x
in
axes
else
1
for
x
,
y
in
enumerate
(
shape
)]
axes_iterables
=
\
axes_iterables
=
\
[
range
(
y
)
for
x
,
y
in
enumerate
(
shape
)
if
x
not
in
axes
]
for
index
in
product
(
*
axes_iterables
):
it_iter
=
iter
(
index
)
slice_list
=
[
next
(
it_iter
)
if
axis
else
slice
(
None
,
None
)
for
axis
in
axes_select
]
]
yield
slice_list
else
:
yield
[
slice
(
None
,
None
)]
...
...
@@ -68,7 +68,7 @@ def hermitianize_gaussian(x, axes=None):
# The fixed points of the point inversion must not be avaraged.
# Hence one must multiply them again with sqrt(0.5)
# -> Get the middle index of the array
mid_index
=
np
.
array
(
x
.
shape
,
dtype
=
np
.
int
)
//
2
mid_index
=
np
.
array
(
x
.
shape
,
dtype
=
np
.
int
)
//
2
dimensions
=
mid_index
.
size
# Use ndindex to iterate over all combinations of zeros and the
# mid_index in order to correct all fixed points.
...
...
@@ -78,7 +78,7 @@ def hermitianize_gaussian(x, axes=None):
ndlist
=
[
2
if
i
in
axes
else
1
for
i
in
xrange
(
dimensions
)]
ndlist
=
tuple
(
ndlist
)
for
i
in
np
.
ndindex
(
ndlist
):
temp_index
=
tuple
(
i
*
mid_index
)
temp_index
=
tuple
(
i
*
mid_index
)
x
[
temp_index
]
*=
np
.
sqrt
(
0.5
)
try
:
x
.
hermitian
=
True
...
...
@@ -109,7 +109,7 @@ def _hermitianize_inverter(x, axes):
# calculate the number of dimensions the input array has
dimensions
=
len
(
x
.
shape
)
# prepare the slicing object which will be used for mirroring
slice_primitive
=
[
slice
(
None
),
]
*
dimensions
slice_primitive
=
[
slice
(
None
),
]
*
dimensions
# copy the input data
y
=
x
.
copy
()
...
...
@@ -208,8 +208,9 @@ def field_map(ishape, function, *args):
# with ishape (3,4,3) and (3,4,1)
def
get_clipped
(
w
,
ind
):
w_shape
=
np
.
array
(
np
.
shape
(
w
))
get_tuple
=
tuple
(
np
.
clip
(
ind
,
0
,
w_shape
-
1
))
get_tuple
=
tuple
(
np
.
clip
(
ind
,
0
,
w_shape
-
1
))
return
w
[
get_tuple
]
result
=
np
.
empty_like
(
args
[
0
])
for
i
in
xrange
(
reduce
(
lambda
x
,
y
:
x
*
y
,
result
.
shape
)):
ii
=
np
.
unravel_index
(
i
,
result
.
shape
)
...
...
@@ -229,7 +230,7 @@ def cast_axis_to_tuple(axis, length):
axis
=
tuple
(
int
(
item
)
for
item
in
axis
)
except
(
TypeError
):
if
np
.
isscalar
(
axis
):
axis
=
(
int
(
axis
),
)
axis
=
(
int
(
axis
),)
else
:
raise
TypeError
(
"ERROR: Could not convert axis-input to tuple of ints"
)
...
...
@@ -242,7 +243,7 @@ def cast_axis_to_tuple(axis, length):
# assert that all entries are elements in [0, length]
for
elem
in
axis
:
assert
(
0
<=
elem
<
length
)
assert
(
0
<=
elem
<
length
)
return
axis
...
...
@@ -262,3 +263,21 @@ def complex_bincount(x, weights=None, minlength=None):
return
real_bincount
+
imag_bincount
else
:
return
x
.
bincount
(
weights
=
weights
,
minlength
=
minlength
)
def
get_default_codomain
(
domain
):
from
nifty.spaces
import
RGSpace
,
HPSpace
,
GLSpace
,
LMSpace
from
nifty.operators.fft_operator.transformations
import
RGRGTransformation
,
\
HPLMTransformation
,
GLLMTransformation
,
LMGLTransformation
if
isinstance
(
domain
,
RGSpace
):
return
RGRGTransformation
.
get_codomain
(
domain
)
elif
isinstance
(
domain
,
HPSpace
):
return
HPLMTransformation
.
get_codomain
(
domain
)
elif
isinstance
(
domain
,
GLSpace
):
return
GLLMTransformation
.
get_codomain
(
domain
)
elif
isinstance
(
domain
,
LMSpace
):
# TODO: get the preferred transformation path from config
return
LMGLTransformation
.
get_codomain
(
domain
)
else
:
raise
TypeError
(
about
.
_errors
.
cstring
(
'ERROR: unknown domain'
))
nifty/operators/__init__.py
View file @
162f4066
...
...
@@ -25,6 +25,8 @@ from linear_operator import LinearOperator
from
endomorphic_operator
import
EndomorphicOperator
from
fft_operator
import
*
from
nifty_operators
import
operator
,
\
diagonal_operator
,
\
power_operator
,
\
...
...
nifty/operators/endomorphic_operator/__init__.py
View file @
162f4066
# -*- coding: utf-8 -*-
from
endmorphic_operator
import
EndomorphicOperator
from
end
o
morphic_operator
import
EndomorphicOperator
nifty/operators/fft_operator/__init__.py
0 → 100644
View file @
162f4066
from
transformations
import
*
from
fft_operator
import
FFTOperator
\ No newline at end of file
nifty/operators/fft_operator/fft_operator.py
0 → 100644
View file @
162f4066
from
nifty.config
import
about
import
nifty.nifty_utilities
as
utilities
from
nifty.operators.linear_operator
import
LinearOperator
from
transformations
import
TransformationFactory
class
FFTOperator
(
LinearOperator
):
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
=
(),
field_type
=
(),
target
=
(),
field_type_target
=
(),
implemented
=
True
):
super
(
FFTOperator
,
self
).
__init__
(
domain
=
domain
,
field_type
=
field_type
,
implemented
=
implemented
)
if
self
.
domain
==
():
raise
TypeError
(
about
.
_errors
.
cstring
(
'ERROR: TransformationOperator needs a single space as '
'input domain.'
))
else
:
if
len
(
self
.
domain
)
>
1
:
raise
TypeError
(
about
.
_errors
.
cstring
(
'ERROR: TransformationOperator accepts only a single '
'space as input domain.'
))
if
self
.
field_type
!=
():
raise
TypeError
(
about
.
_errors
.
cstring
(
'ERROR: TransformationOperator field-type has to be an '
'empty tuple.'
))
# currently not sanitizing the target
self
.
_target
=
self
.
_parse_domain
(
utilities
.
get_default_codomain
(
self
.
domain
[
0
])
)
self
.
_field_type_target
=
self
.
_parse_field_type
(
field_type_target
)
if
self
.
field_type_target
!=
():
raise
TypeError
(
about
.
_errors
.
cstring
(
'ERROR: TransformationOperator target field-type has to be an '
'empty tuple.'
))
self
.
_forward_transformation
=
TransformationFactory
.
create
(
self
.
domain
[
0
],
self
.
target
[
0
]
)
self
.
_inverse_transformation
=
TransformationFactory
.
create
(
self
.
target
[
0
],
self
.
domain
[
0
]
)
def
adjoint_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
return
self
.
inverse_times
(
x
,
spaces
,
types
)
def
adjoint_inverse_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
return
self
.
times
(
x
,
spaces
,
types
)
def
inverse_adjoint_times
(
self
,
x
,
spaces
=
None
,
types
=
None
):
return
self
.
times
(
x
,
spaces
,
types
)
def
_times
(
self
,
x
,
spaces
,
types
):
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
))
return
self
.
_forward_transformation
.
transform
(
x
.
val
,
axes
=
spaces
)
def
_inverse_times
(
self
,
x
,
spaces
,
types
):
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
))
return
self
.
_inverse_transformation
.
transform
(
x
.
val
,
axes
=
spaces
)
# ---Mandatory properties and methods---
@
property
def
target
(
self
):
return
self
.
_target
@
property
def
field_type_target
(
self
):
return
self
.
_field_type_target
nifty/transformations/__init__.py
→
nifty/
operators/fft_operator/
transformations/__init__.py
View file @
162f4066
from
rgrgtransformation
import
RGRGTransformation
from
gllmtransformation
import
GLLMTransformation
from
hplmtransformation
import
HPLMTransformation
from
lmgltransformation
import
LMGLTransformation
from
lmhptransformation
import
LMHPTransformation
from
transformation_factory
import
TransformationFactory
from
transformation_factory
import
TransformationFactory
\ No newline at end of file
nifty/transformations/gllmtransformation.py
→
nifty/
operators/fft_operator/
transformations/gllmtransformation.py
View file @
162f4066
...
...
@@ -87,7 +87,7 @@ class GLLMTransformation(Transformation):
"""
if
self
.
domain
.
discrete
:
val
=
self
.
domain
.
calc_
weight
(
val
,
power
=-
0.5
)
val
=
self
.
domain
.
weight
(
val
,
power
=-
0.5
,
axes
=
axes
)
# shorthands for transform parameters
nlat
=
self
.
domain
.
paradict
[
'nlat'
]
...
...
nifty/transformations/hplmtransformation.py
→
nifty/
operators/fft_operator/
transformations/hplmtransformation.py
View file @
162f4066
...
...
@@ -85,7 +85,7 @@ class HPLMTransformation(Transformation):
niter
=
kwargs
[
'niter'
]
if
'niter'
in
kwargs
else
0
if
self
.
domain
.
discrete
:
val
=
self
.
domain
.
calc_
weight
(
val
,
power
=-
0.5
)
val
=
self
.
domain
.
weight
(
val
,
power
=-
0.5
,
axes
=
axes
)
# shorthands for transform parameters
lmax
=
self
.
codomain
.
paradict
[
'lmax'
]
...
...
nifty/transformations/lmgltransformation.py
→
nifty/
operators/fft_operator/
transformations/lmgltransformation.py
View file @
162f4066
...
...
@@ -131,7 +131,7 @@ class LMGLTransformation(Transformation):
# re-weight if discrete
if
self
.
codomain
.
discrete
:
val
=
self
.
codomain
.
calc_
weight
(
val
,
power
=
0.5
)
val
=
self
.
codomain
.
weight
(
val
,
power
=
0.5
,
axes
=
axes
)
if
isinstance
(
val
,
distributed_data_object
):
new_val
=
val
.
copy_empty
(
dtype
=
self
.
codomain
.
dtype
)
...
...
nifty/transformations/lmhptransformation.py
→
nifty/
operators/fft_operator/
transformations/lmhptransformation.py
View file @
162f4066
...
...
@@ -114,7 +114,7 @@ class LMHPTransformation(Transformation):
# re-weight if discrete
if
self
.
codomain
.
discrete
:
val
=
self
.
codomain
.
calc_
weight
(
val
,
power
=
0.5
)
val
=
self
.
codomain
.
weight
(
val
,
power
=
0.5
,
axes
=
axes
)
if
isinstance
(
val
,
distributed_data_object
):
new_val
=
val
.
copy_empty
(
dtype
=
self
.
codomain
.
dtype
)
...
...
nifty/transformations/rg_transforms.py
→
nifty/
operators/fft_operator/
transformations/rg_transforms.py
View file @
162f4066
...
...
@@ -359,8 +359,7 @@ class FFTW(Transform):
local_shape
=
val
.
local_shape
,
local_offset_Q
=
local_offset_Q
,
is_local
=
False
,
transform_shape
=
val
.
shape
,
# TODO: check why inp.shape doesn't work
transform_shape
=
inp
.
shape
,
**
kwargs
)
...
...
@@ -437,10 +436,6 @@ class FFTW(Transform):
val
,
axes
,
**
kwargs
)
# If domain is purely real, the result of the FFT is hermitian
if
self
.
domain
.
paradict
[
'complexity'
]
==
0
:
return_val
.
hermitian
=
True
return
return_val
...
...
@@ -636,10 +631,6 @@ class GFFT(Transform):
if
isinstance
(
val
,
distributed_data_object
):
new_val
=
val
.
copy_empty
(
dtype
=
self
.
codomain
.
dtype
)
new_val
.
set_full_data
(
return_val
,
copy
=
False
)
# If the values living in domain are purely real, the result of
# the fft is hermitian
if
self
.
domain
.
paradict
[
'complexity'
]
==
0
:
new_val
.
hermitian
=
True
return_val
=
new_val
else
:
return_val
=
return_val
.
astype
(
self
.
codomain
.
dtype
,
copy
=
False
)
...
...
nifty/transformations/rgrgtransformation.py
→
nifty/
operators/fft_operator/
transformations/rgrgtransformation.py
View file @
162f4066
...
...
@@ -126,13 +126,13 @@ class RGRGTransformation(Transformation):
"""
if
self
.
_transform
.
codomain
.
harmonic
:
# correct for forward fft
val
=
self
.
_transform
.
domain
.
calc_
weight
(
val
,
power
=
1
)
val
=
self
.
_transform
.
domain
.
weight
(
val
,
power
=
1
,
axes
=
axes
)
# Perform the transformation
Tval
=
self
.
_transform
.
transform
(
val
,
axes
,
**
kwargs
)
if
not
self
.
_transform
.
codomain
.
harmonic
:
# correct for inverse fft
Tval
=
self
.
_transform
.
codomain
.
calc_
weight
(
Tval
,
power
=-
1
)
Tval
=
self
.
_transform
.
codomain
.
weight
(
Tval
,
power
=-
1
,
axes
=
axes
)
return
Tval
nifty/transformations/transformation.py
→
nifty/
operators/fft_operator/
transformations/transformation.py
View file @
162f4066
File moved
nifty/transformations/transformation_factory.py
→
nifty/
operators/fft_operator/
transformations/transformation_factory.py
View file @
162f4066
File moved
nifty/operators/linear_operator/__init__.py
View file @
162f4066
# -*- coding: utf-8 -*-
from
linear_operator
import
LinearOperator
from
linear_operator_paradict
import
LinearOperatorParadict
nifty/spaces/lm_space/lm_space.py
View file @
162f4066
...
...
@@ -16,7 +16,7 @@ from nifty.config import about,\
dependency_injector
as
gdi
from
lm_space_paradict
import
LMSpaceParadict
from
nifty.nifty_power_indices
import
lm_power_indices
#
from nifty.nifty_power_indices import lm_power_indices
from
nifty.nifty_random
import
random
gl
=
gdi
.
get
(
'libsharp_wrapper_gl'
)
...
...
test/test_nifty_transforms.py
View file @
162f4066
import
numpy
as
np
from
numpy.testing
import
assert_equal
,
assert_almost_equal
,
assert_raises
import
itertools
import
unittest
import
d2o
import
numpy
as
np
from
nifty.rg.rg_space
import
gc
as
RG_GC
from
nose_parameterized
import
parameterized
import
unittest
import
itertools
from
numpy.testing
import
assert_raises
from
nifty
import
RGSpace
,
LMSpace
,
HPSpace
,
GLSpace
from
nifty
import
transformator
from
nifty.transformations.rgrgtransformation
import
RGRGTransformation
from
nifty.rg.rg_space
import
gc
as
RG_GC
import
d2o
from
nifty.operators.fft_operator.transformations.rgrgtransformation
import
RGRGTransformation
###############################################################################
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment