Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
3454501e
Commit
3454501e
authored
Sep 14, 2016
by
Jait Dixit
Browse files
Merge branch 'feature/field_multiple_space' into smooth_operator
parents
aaa02478
f21c8952
Changes
29
Hide whitespace changes
Inline
Side-by-side
nifty/__init__.py
View file @
3454501e
...
...
@@ -55,6 +55,8 @@ from spaces import *
from
operators
import
*
from
probing
import
*
from
demos
import
get_demo_dir
#import pyximport; pyximport.install(pyimport = True)
nifty/field.py
View file @
3454501e
...
...
@@ -4,7 +4,8 @@ import numpy as np
from
d2o
import
distributed_data_object
,
\
STRATEGIES
as
DISTRIBUTION_STRATEGIES
from
nifty.config
import
about
,
nifty_configuration
as
gc
from
nifty.config
import
about
,
\
nifty_configuration
as
gc
from
nifty.field_types
import
FieldType
...
...
@@ -489,15 +490,23 @@ class Field(object):
else
:
dtype
=
np
.
dtype
(
dtype
)
casted_x
=
self
.
_actual_cast
(
x
,
dtype
=
dtype
)
for
ind
,
sp
in
enumerate
(
self
.
domain
):
casted_x
=
sp
.
pre_cast
(
x
,
axes
=
self
.
domain_axes
[
ind
])
for
ind
,
ft
in
enumerate
(
self
.
field_type
):
casted_x
=
ft
.
pre_cast
(
casted_x
,
axes
=
self
.
field_type_axes
[
ind
])
casted_x
=
self
.
_actual_cast
(
casted_x
,
dtype
=
dtype
)
for
ind
,
sp
in
enumerate
(
self
.
domain
):
casted_x
=
sp
.
complemen
t_cast
(
casted_x
,
axes
=
self
.
domain_axes
[
ind
])
casted_x
=
sp
.
pos
t_cast
(
casted_x
,
axes
=
self
.
domain_axes
[
ind
])
for
ind
,
ft
in
enumerate
(
self
.
field_type
):
casted_x
=
ft
.
complemen
t_cast
(
casted_x
,
axes
=
self
.
field_type_axes
[
ind
])
casted_x
=
ft
.
pos
t_cast
(
casted_x
,
axes
=
self
.
field_type_axes
[
ind
])
return
casted_x
...
...
@@ -576,7 +585,7 @@ class Field(object):
new_field
.
__class__
=
self
.
__class__
# copy domain, codomain and val
for
key
,
value
in
self
.
__dict__
.
items
():
if
key
!=
'val'
:
if
key
!=
'
_
val'
:
new_field
.
__dict__
[
key
]
=
value
else
:
new_field
.
__dict__
[
key
]
=
self
.
val
.
copy_empty
()
...
...
nifty/field_types/field_type.py
View file @
3454501e
...
...
@@ -51,5 +51,8 @@ class FieldType(object):
return
result_array
def
complement_cast
(
self
,
x
,
axes
=
None
):
def
pre_cast
(
self
,
x
,
axes
=
None
):
return
x
def
post_cast
(
self
,
x
,
axes
=
None
):
return
x
nifty/operators/diagonal_operator/diagonal_operator.py
View file @
3454501e
...
...
@@ -15,12 +15,11 @@ class DiagonalOperator(EndomorphicOperator):
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
=
(),
field_type
=
(),
implemented
=
Fals
e
,
def
__init__
(
self
,
domain
=
(),
field_type
=
(),
implemented
=
Tru
e
,
diagonal
=
None
,
bare
=
False
,
copy
=
True
,
distribution_strategy
=
None
):
super
(
DiagonalOperator
,
self
).
__init__
(
domain
=
domain
,
field_type
=
field_type
,
implemented
=
implemented
)
field_type
=
field_type
)
self
.
_implemented
=
bool
(
implemented
)
...
...
@@ -30,61 +29,27 @@ class DiagonalOperator(EndomorphicOperator):
elif
isinstance
(
diagonal
,
Field
):
distribution_strategy
=
diagonal
.
distribution_strategy
self
.
distribution_strategy
=
self
.
_parse_distribution_strategy
(
self
.
_
distribution_strategy
=
self
.
_parse_distribution_strategy
(
distribution_strategy
=
distribution_strategy
,
val
=
diagonal
)
self
.
set_diagonal
(
diagonal
=
diagonal
,
bare
=
bare
,
copy
=
copy
)
def
_times
(
self
,
x
,
spaces
,
types
):
# if the distribution_strategy of self is sub-slice compatible to
# the one of x, reshape the local data of self and apply it directly
active_axes
=
[]
if
spaces
is
None
:
for
axes
in
x
.
domain_axes
:
active_axes
+=
axes
else
:
for
space_index
in
spaces
:
active_axes
+=
x
.
domain_axes
[
space_index
]
if
types
is
None
:
for
axes
in
x
.
field_type_axes
:
active_axes
+=
axes
else
:
for
type_index
in
types
:
active_axes
+=
x
.
field_type_axes
[
type_index
]
if
x
.
val
.
get_axes_local_distribution_strategy
(
active_axes
)
==
\
self
.
distribution_strategy
:
local_data
=
self
.
_diagonal
.
val
.
get_local_data
(
copy
=
False
)
# check if domains match completely
# -> multiply directly
# check if axes_local_distribution_strategy matches.
# If yes, extract local data of self.diagonal and x and use numpy
# reshape.
# assert that indices in spaces and types are striktly increasing
# otherwise a wild transpose would be necessary
# build new shape (1,1,x,1,y,1,1,z)
# copy self.diagonal into new shape
# apply reshaped array to x
return
self
.
_times_helper
(
x
,
spaces
,
types
,
operation
=
lambda
z
:
z
.
__mul__
)
def
_adjoint_times
(
self
,
x
,
spaces
,
types
):
pass
return
self
.
_times_helper
(
x
,
spaces
,
types
,
operation
=
lambda
z
:
z
.
adjoint
().
__mul__
)
def
_inverse_times
(
self
,
x
,
spaces
,
types
):
pass
return
self
.
_times_helper
(
x
,
spaces
,
types
,
operation
=
lambda
z
:
z
.
__rdiv__
)
def
_adjoint_inverse_times
(
self
,
x
,
spaces
,
types
):
pass
def
_inverse_adjoint_times
(
self
,
x
,
spaces
,
types
):
pass
return
self
.
_times_helper
(
x
,
spaces
,
types
,
operation
=
lambda
z
:
z
.
adjoint
().
__rdiv__
)
def
diagonal
(
self
,
bare
=
False
,
copy
=
True
):
if
bare
:
...
...
@@ -178,3 +143,49 @@ class DiagonalOperator(EndomorphicOperator):
# store the diagonal-field
self
.
_diagonal
=
f
def
_times_helper
(
self
,
x
,
spaces
,
types
,
operation
):
# if the domain and field_type match directly
# -> multiply the fields directly
if
x
.
domain
==
self
.
domain
and
x
.
field_type
==
self
.
field_type
:
# here the actual multiplication takes place
return
operation
(
self
.
diagonal
(
copy
=
False
))(
x
)
# if the distribution_strategy of self is sub-slice compatible to
# the one of x, reshape the local data of self and apply it directly
active_axes
=
[]
if
spaces
is
None
:
for
axes
in
x
.
domain_axes
:
active_axes
+=
axes
else
:
for
space_index
in
spaces
:
active_axes
+=
x
.
domain_axes
[
space_index
]
if
types
is
None
:
for
axes
in
x
.
field_type_axes
:
active_axes
+=
axes
else
:
for
type_index
in
types
:
active_axes
+=
x
.
field_type_axes
[
type_index
]
axes_local_distribution_strategy
=
\
x
.
val
.
get_axes_local_distribution_strategy
(
active_axes
)
if
axes_local_distribution_strategy
==
self
.
distribution_strategy
:
local_diagonal
=
self
.
_diagonal
.
val
.
get_local_data
(
copy
=
False
)
else
:
# create an array that is sub-slice compatible
redistr_diagonal_val
=
self
.
_diagonal
.
val
.
copy
(
distribution_strategy
=
axes_local_distribution_strategy
)
local_diagonal
=
redistr_diagonal_val
.
get_local_data
(
copy
=
False
)
reshaper
=
[
x
.
shape
[
i
]
if
i
in
active_axes
else
1
for
i
in
xrange
(
len
(
x
.
shape
))]
reshaped_local_diagonal
=
np
.
reshape
(
local_diagonal
,
reshaper
)
# here the actual multiplication takes place
local_result
=
operation
(
reshaped_local_diagonal
)(
x
.
val
.
get_local_data
(
copy
=
False
))
result_field
=
x
.
copy_empty
(
dtype
=
local_result
.
dtype
)
result_field
.
val
.
set_local_data
(
local_result
,
copy
=
False
)
return
result_field
nifty/operators/fft_operator/fft_operator.py
View file @
3454501e
from
nifty.config
import
about
import
nifty.nifty_utilities
as
utilities
from
nifty.spaces
import
RGSpace
,
\
GLSpace
,
\
HPSpace
,
\
LMSpace
from
nifty.operators.linear_operator
import
LinearOperator
from
transformations
import
TransformationFactory
from
transformations
import
RGRGTransformation
,
\
LMGLTransformation
,
\
LMHPTransformation
,
\
GLLMTransformation
,
\
HPLMTransformation
,
\
TransformationCache
class
FFTOperator
(
LinearOperator
):
# ---Class attributes---
default_codomain_dictionary
=
{
RGSpace
:
RGSpace
,
HPSpace
:
LMSpace
,
GLSpace
:
LMSpace
,
LMSpace
:
HPSpace
,
}
transformation_dictionary
=
{(
RGSpace
,
RGSpace
):
RGRGTransformation
,
(
HPSpace
,
LMSpace
):
HPLMTransformation
,
(
GLSpace
,
LMSpace
):
GLLMTransformation
,
(
LMSpace
,
HPSpace
):
LMHPTransformation
,
(
LMSpace
,
GLSpace
):
LMGLTransformation
}
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
=
(),
field_type
=
(),
target
=
None
):
def
__init__
(
self
,
domain
=
(),
field_type
=
(),
target
=
None
,
module
=
None
):
super
(
FFTOperator
,
self
).
__init__
(
domain
=
domain
,
field_type
=
field_type
)
# Initialize domain and target
if
len
(
self
.
domain
)
!=
1
:
raise
ValueError
(
about
.
_errors
.
cstring
(
'ERROR: TransformationOperator accepts only exactly one '
...
...
@@ -24,17 +51,30 @@ class FFTOperator(LinearOperator):
))
if
target
is
None
:
target
=
utilities
.
get_default_codomain
(
self
.
domain
[
0
])
target
=
(
self
.
get_default_codomain
(
self
.
domain
[
0
]),
)
self
.
_target
=
self
.
_parse_domain
(
target
)
self
.
_forward_transformation
=
TransformationFactory
.
create
(
self
.
domain
[
0
],
self
.
target
[
0
]
)
self
.
_inverse_transformation
=
TransformationFactory
.
create
(
self
.
target
[
0
],
self
.
domain
[
0
]
)
# Create transformation instances
try
:
forward_class
=
self
.
transformation_dictionary
[
(
self
.
domain
[
0
].
__class__
,
self
.
target
[
0
].
__class__
)]
except
KeyError
:
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: No forward transformation for domain-target pair "
"found."
))
try
:
backward_class
=
self
.
transformation_dictionary
[
(
self
.
target
[
0
].
__class__
,
self
.
domain
[
0
].
__class__
)]
except
KeyError
:
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: No backward transformation for domain-target pair "
"found."
))
self
.
_forward_transformation
=
TransformationCache
.
create
(
forward_class
,
self
.
domain
[
0
],
self
.
target
[
0
],
module
=
module
)
self
.
_backward_transformation
=
TransformationCache
.
create
(
backward_class
,
self
.
target
[
0
],
self
.
domain
[
0
],
module
=
module
)
def
_times
(
self
,
x
,
spaces
,
types
):
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
))
...
...
@@ -69,7 +109,7 @@ class FFTOperator(LinearOperator):
else
:
axes
=
x
.
domain_axes
[
spaces
[
0
]]
new_val
=
self
.
_
inverse
_transformation
.
transform
(
x
.
val
,
axes
=
axes
)
new_val
=
self
.
_
backward
_transformation
.
transform
(
x
.
val
,
axes
=
axes
)
if
spaces
is
None
:
result_domain
=
self
.
domain
...
...
@@ -99,3 +139,22 @@ class FFTOperator(LinearOperator):
@
property
def
unitary
(
self
):
return
True
# ---Added properties and methods---
@
classmethod
def
get_default_codomain
(
cls
,
domain
):
domain_class
=
domain
.
__class__
try
:
codomain_class
=
cls
.
default_codomain_dictionary
[
domain_class
]
except
KeyError
:
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: unknown domain"
))
try
:
transform_class
=
cls
.
transformation_dictionary
[(
domain_class
,
codomain_class
)]
except
KeyError
:
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: No transformation for domain-codomain pair found."
))
return
transform_class
.
get_codomain
(
domain
)
nifty/operators/fft_operator/transformations/__init__.py
View file @
3454501e
...
...
@@ -4,4 +4,4 @@ from hplmtransformation import HPLMTransformation
from
lmgltransformation
import
LMGLTransformation
from
lmhptransformation
import
LMHPTransformation
from
transformation_factory
import
TransformationFactory
\ No newline at end of file
from
transformation_cache
import
TransformationCache
\ No newline at end of file
nifty/operators/fft_operator/transformations/gllmtransformation.py
View file @
3454501e
import
numpy
as
np
from
transformation
import
Transformation
from
d2o
import
distributed_data_object
from
nifty.config
import
dependency_injector
as
gdi
import
nifty.nifty_utilities
as
utilities
from
nifty.config
import
dependency_injector
as
gdi
,
\
about
from
nifty
import
GLSpace
,
LMSpace
from
slicing_transformation
import
SlicingTransformation
import
lm_transformation_factory
as
ltf
g
l
=
gdi
.
get
(
'libsharp_wrapper_gl'
)
l
ibsharp
=
gdi
.
get
(
'libsharp_wrapper_gl'
)
class
GLLMTransformation
(
Transformation
):
class
GLLMTransformation
(
SlicingTransformation
):
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
,
codomain
=
None
,
module
=
None
):
if
'libsharp_wrapper_gl'
not
in
gdi
:
raise
ImportError
(
"The module libsharp is needed but not available"
)
if
codomain
is
None
:
self
.
domain
=
domain
self
.
codomain
=
self
.
get_codomain
(
domain
)
elif
self
.
check_codomain
(
domain
,
codomain
):
self
.
domain
=
domain
self
.
codomain
=
codomain
else
:
raise
ValueError
(
"ERROR: Incompatible codomain!"
)
raise
ImportError
(
about
.
_errors
.
cstring
(
"The module libsharp is needed but not available."
))
@
staticmethod
def
get_codomain
(
domain
):
super
(
GLLMTransformation
,
self
).
__init__
(
domain
,
codomain
,
module
)
# ---Mandatory properties and methods---
@
classmethod
def
get_codomain
(
cls
,
domain
):
"""
Generates a compatible codomain to which transformations are
reasonable, i.e.\ an instance of the :py:class:`lm_space` class.
...
...
@@ -38,96 +37,89 @@ class GLLMTransformation(Transformation):
codomain : LMSpace
A compatible codomain.
"""
if
domain
is
None
:
raise
ValueError
(
'ERROR: cannot generate codomain for None'
)
if
not
isinstance
(
domain
,
GLSpace
):
raise
TypeError
(
'ERROR: domain needs to be a GLSpace'
)
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: domain needs to be a GLSpace"
))
nlat
=
domain
.
nlat
lmax
=
nlat
-
1
mmax
=
nlat
-
1
if
domain
.
dtype
==
np
.
dtype
(
'float32'
):
return
LMSpace
(
lmax
=
lmax
,
mmax
=
mmax
,
dtype
=
np
.
complex64
)
return
_dtype
=
np
.
float32
else
:
return
LMSpace
(
lmax
=
lmax
,
mmax
=
mmax
,
dtype
=
np
.
complex128
)
return_dtype
=
np
.
float64
result
=
LMSpace
(
lmax
=
lmax
,
mmax
=
mmax
,
dtype
=
return_dtype
)
cls
.
check_codomain
(
domain
,
result
)
return
result
@
staticmethod
def
check_codomain
(
domain
,
codomain
):
if
not
isinstance
(
domain
,
GLSpace
):
raise
TypeError
(
'ERROR: domain is not a GLSpace'
)
if
codomain
is
None
:
return
False
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: domain is not a GLSpace"
))
if
not
isinstance
(
codomain
,
LMSpace
):
raise
TypeError
(
'ERROR: codomain must be a LMSpace.'
)
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: codomain must be a LMSpace."
))
nlat
=
domain
.
nlat
nlon
=
domain
.
nlon
lmax
=
codomain
.
lmax
mmax
=
codomain
.
mmax
if
(
nlon
!=
2
*
nlat
-
1
)
or
(
lmax
!=
nlat
-
1
)
or
(
lmax
!=
mmax
):
return
False
if
lmax
!=
mmax
:
raise
ValueError
(
about
.
_errors
.
cstring
(
'ERROR: codomain has lmax != mmax.'
))
return
True
def
transform
(
self
,
val
,
axes
=
None
,
**
kwargs
):
"""
GL -> LM transform method.
if
lmax
!=
nlat
-
1
:
raise
ValueError
(
about
.
_errors
.
cstring
(
'ERROR: codomain has lmax != nlat - 1.'
))
Parameters
----------
val : np.ndarray or distributed_data_object
The value array which is to be transformed
if
nlon
!=
2
*
nlat
-
1
:
raise
ValueError
(
about
.
_errors
.
cstring
(
'ERROR: domain has nlon != 2 * nlat - 1.'
))
axes : None or tuple
The axes along which the transformation should take place
return
None
"""
if
self
.
domain
.
discrete
:
val
=
self
.
domain
.
weight
(
val
,
power
=-
0.5
,
axes
=
axes
)
# shorthands for transform parameters
def
_transformation_of_slice
(
self
,
inp
,
**
kwargs
):
nlat
=
self
.
domain
.
nlat
nlon
=
self
.
domain
.
nlon
lmax
=
self
.
codomain
.
lmax
mmax
=
self
.
codomain
.
mmax
if
isinstance
(
val
,
distributed_data_object
):
temp_val
=
val
.
get_full_data
()
else
:
temp_val
=
val
return_val
=
None
for
slice_list
in
utilities
.
get_slice_list
(
temp_val
.
shape
,
axes
):
if
slice_list
==
[
slice
(
None
,
None
)]:
inp
=
temp_val
else
:
if
return_val
is
None
:
return_val
=
np
.
empty_like
(
temp_val
)
inp
=
temp_val
[
slice_list
]
if
self
.
domain
.
dtype
==
np
.
dtype
(
'float32'
):
inp
=
gl
.
map2alm_f
(
inp
,
nlat
=
nlat
,
nlon
=
nlon
,
lmax
=
lmax
,
mmax
=
mmax
)
else
:
inp
=
gl
.
map2alm
(
inp
,
nlat
=
nlat
,
nlon
=
nlon
,
lmax
=
lmax
,
mmax
=
mmax
)
if
slice_list
==
[
slice
(
None
,
None
)]:
return_val
=
inp
else
:
return_val
[
slice_list
]
=
inp
if
isinstance
(
val
,
distributed_data_object
):
new_val
=
val
.
copy_empty
(
dtype
=
self
.
codomain
.
dtype
)
new_val
.
set_full_data
(
return_val
,
copy
=
False
)
if
issubclass
(
inp
.
dtype
.
type
,
np
.
complexfloating
):
[
resultReal
,
resultImag
]
=
[
self
.
libsharpMap2Alm
(
x
,
nlat
=
nlat
,
nlon
=
nlon
,
lmax
=
lmax
,
mmax
=
mmax
,
**
kwargs
)
for
x
in
(
inp
.
real
,
inp
.
imag
)]
[
resultReal
,
resultImag
]
=
[
ltf
.
buildIdx
(
x
,
lmax
=
lmax
)
for
x
in
[
resultReal
,
resultImag
]]
result
=
self
.
_combine_complex_result
(
resultReal
,
resultImag
)
else
:
return_val
=
return_val
.
astype
(
self
.
codomain
.
dtype
,
copy
=
False
)
result
=
self
.
libsharpMap2Alm
(
inp
,
nlat
=
nlat
,
nlon
=
nlon
,
lmax
=
lmax
,
mmax
=
mmax
)
result
=
ltf
.
buildIdx
(
result
,
lmax
=
lmax
)
return
result
return
return_val
# ---Added properties and methods---
def
libsharpMap2Alm
(
self
,
inp
,
**
kwargs
):
if
inp
.
dtype
==
np
.
dtype
(
'float32'
):
return
libsharp
.
map2alm_f
(
inp
,
**
kwargs
)
elif
inp
.
dtype
==
np
.
dtype
(
'float64'
):
return
libsharp
.
map2alm
(
inp
,
**
kwargs
)
else
:
about
.
warnings
.
cprint
(
"WARNING: performing dtype conversion for "
"libsharp compatibility."
)
casted_inp
=
inp
.
astype
(
np
.
dtype
(
'float64'
),
copy
=
False
)
result
=
libsharp
.
map2alm
(
casted_inp
,
**
kwargs
)
return
result
\ No newline at end of file
nifty/operators/fft_operator/transformations/hplmtransformation.py
View file @
3454501e
import
numpy
as
np
from
transformation
import
Transformation
from
d2o
import
distributed_data_object
from
nifty.config
import
dependency_injector
as
gdi
import
nifty.nifty_utilities
as
utilities
from
nifty.config
import
dependency_injector
as
gdi
,
\
about
from
nifty
import
HPSpace
,
LMSpace
from
slicing_transformation
import
SlicingTransformation
import
lm_transformation_factory
as
ltf
hp
=
gdi
.
get
(
'healpy'
)
class
HPLMTransformation
(
Transformation
):
class
HPLMTransformation
(
SlicingTransformation
):
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
,
codomain
=
None
,
module
=
None
):
if
'healpy'
not
in
gdi
:
raise
ImportError
(
"The module healpy is needed but not available"
)
if
codomain
is
None
:
self
.
domain
=
domain
self
.
codomain
=
self
.
get_codomain
(
domain
)
elif
self
.
check_codomain
(
domain
,
codomain
):
self
.
domain
=
domain
self
.
codomain
=
codomain
else
:
raise
ValueError
(
"ERROR: Incompatible codomain!"
)
raise
ImportError
(
about
.
_errors
.
cstring
(
"The module healpy is needed but not available"
))
@
staticmethod
def
get_codomain
(
domain
):
super
(
HPLMTransformation
,
self
).
__init__
(
domain
,
codomain
,
module
)
# ---Mandatory properties and methods---
@
classmethod
def
get_codomain
(
cls
,
domain
):
"""
Generates a compatible codomain to which transformations are
reasonable, i.e.\ an instance of the :py:class:`lm_space` class.
...
...
@@ -38,87 +38,65 @@ class HPLMTransformation(Transformation):
codomain : LMSpace
A compatible codomain.
"""
if
domain
is
None
:
raise
ValueError
(
'ERROR: cannot generate codomain for None'
)
if
not
isinstance
(
domain
,
HPSpace
):
raise
TypeError
(
'ERROR: domain needs to be a HPSpace'
)
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: domain needs to be a HPSpace"
))
lmax
=
3
*
domain
.
nside
-
1
mmax
=
lmax
return
LMSpace
(
lmax
=
lmax
,
mmax
=
mmax
,
dtype
=
np
.
dtype
(
'complex128'
))