Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
c31c7198
Commit
c31c7198
authored
Mar 22, 2017
by
Theo Steininger
Browse files
Merge branch 'master' into tests
parents
edab13c6
f8377deb
Pipeline
#10843
failed with stage
in 22 minutes and 5 seconds
Changes
19
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty/domain_object.py
View file @
c31c7198
...
...
@@ -48,6 +48,11 @@ class DomainObject(Versionable, Loggable, object):
raise
NotImplementedError
(
"There is no generic dim for DomainObject."
)
@
abc
.
abstractmethod
def
weight
(
self
,
x
,
power
=
1
,
axes
=
None
,
inplace
=
False
):
raise
NotImplementedError
(
"There is no generic weight-method for DomainObject."
)
def
pre_cast
(
self
,
x
,
axes
=
None
):
return
x
...
...
nifty/field.py
View file @
c31c7198
...
...
@@ -90,7 +90,7 @@ class Field(Loggable, Versionable, object):
elif
isinstance
(
val
,
Field
):
distribution_strategy
=
val
.
distribution_strategy
else
:
self
.
logger
.
info
(
"Datamodel
set to default!"
)
self
.
logger
.
debug
(
"distribution_strategy
set to default!"
)
distribution_strategy
=
gc
[
'default_distribution_strategy'
]
elif
distribution_strategy
not
in
DISTRIBUTION_STRATEGIES
[
'global'
]:
raise
ValueError
(
...
...
@@ -151,11 +151,11 @@ class Field(Loggable, Versionable, object):
def
power_analyze
(
self
,
spaces
=
None
,
log
=
False
,
nbin
=
None
,
binbounds
=
None
,
real_signal
=
True
):
#
assert that
all spaces in `self.domain` are either harmonic or
#
check if
all spaces in `self.domain` are either harmonic or
# power_space instances
for
sp
in
self
.
domain
:
if
not
sp
.
harmonic
and
not
isinstance
(
sp
,
PowerSpace
):
raise
AttributeError
(
self
.
logger
.
info
(
"Field has a space in `domain` which is neither "
"harmonic nor a PowerSpace."
)
...
...
@@ -287,11 +287,12 @@ class Field(Loggable, Versionable, object):
def
power_synthesize
(
self
,
spaces
=
None
,
real_signal
=
True
,
mean
=
None
,
std
=
None
):
# assert that all spaces in `self.domain` are either of signal-type or
# check if all spaces in `self.domain` are either of signal-type or
# power_space instances
for
sp
in
self
.
domain
:
if
not
sp
.
harmonic
and
not
isinstance
(
sp
,
PowerSpace
):
raise
AttributeError
(
self
.
logger
.
info
(
"Field has a space in `domain` which is neither "
"harmonic nor a PowerSpace."
)
...
...
@@ -347,7 +348,8 @@ class Field(Loggable, Versionable, object):
if
real_signal
:
result_val_list
=
[
harmonic_domain
.
hermitian_decomposition
(
x
.
val
,
axes
=
x
.
domain_axes
[
power_space_index
])[
0
]
axes
=
x
.
domain_axes
[
power_space_index
],
preserve_gaussian_variance
=
True
)[
0
]
for
x
in
result_list
]
else
:
result_val_list
=
[
x
.
val
for
x
in
result_list
]
...
...
@@ -556,10 +558,9 @@ class Field(Loggable, Versionable, object):
new_val
=
self
.
get_val
(
copy
=
False
)
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
self
.
domain
))
if
spaces
is
None
:
spaces
=
range
(
len
(
self
.
domain
))
else
:
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
self
.
domain
))
for
ind
,
sp
in
enumerate
(
self
.
domain
):
if
ind
in
spaces
:
...
...
@@ -571,17 +572,11 @@ class Field(Loggable, Versionable, object):
new_field
.
set_val
(
new_val
=
new_val
,
copy
=
False
)
return
new_field
def
dot
(
self
,
x
=
None
,
bare
=
False
):
if
isinstance
(
x
,
Field
):
try
:
assert
len
(
x
.
domain
)
==
len
(
self
.
domain
)
for
index
in
xrange
(
len
(
self
.
domain
)):
assert
x
.
domain
[
index
]
==
self
.
domain
[
index
]
except
AssertionError
:
raise
ValueError
(
"domains are incompatible."
)
# extract the data from x and try to dot with this
x
=
x
.
get_val
(
copy
=
False
)
def
dot
(
self
,
x
=
None
,
spaces
=
None
,
bare
=
False
):
if
not
isinstance
(
x
,
Field
):
raise
ValueError
(
"The dot-partner must be an instance of "
+
"the NIFTy field class"
)
# Compute the dot respecting the fact of discrete/continous spaces
if
bare
:
...
...
@@ -589,14 +584,21 @@ class Field(Loggable, Versionable, object):
else
:
y
=
self
.
weight
(
power
=
1
)
y
=
y
.
get_val
(
copy
=
False
)
# Cast the input in order to cure dtype and shape differences
x
=
self
.
cast
(
x
)
dotted
=
x
.
conjugate
()
*
y
return
dotted
.
sum
()
if
spaces
is
None
:
x_val
=
x
.
get_val
(
copy
=
False
)
y_val
=
y
.
get_val
(
copy
=
False
)
result
=
(
x_val
.
conjugate
()
*
y_val
).
sum
()
return
result
else
:
# create a diagonal operator which is capable of taking care of the
# axes-matching
from
nifty.operators.diagonal_operator
import
DiagonalOperator
diagonal
=
y
.
val
.
conjugate
()
diagonalOperator
=
DiagonalOperator
(
domain
=
y
.
domain
,
diagonal
=
diagonal
,
copy
=
False
)
dotted
=
diagonalOperator
(
x
,
spaces
=
spaces
)
return
dotted
.
sum
(
spaces
=
spaces
)
def
norm
(
self
,
q
=
2
):
"""
...
...
@@ -834,7 +836,10 @@ class Field(Loggable, Versionable, object):
hdf5_group
.
attrs
[
'domain_axes'
]
=
str
(
self
.
domain_axes
)
hdf5_group
[
'num_domain'
]
=
len
(
self
.
domain
)
ret_dict
=
{
'val'
:
self
.
val
}
if
self
.
_val
is
None
:
ret_dict
=
{}
else
:
ret_dict
=
{
'val'
:
self
.
val
}
for
i
in
range
(
len
(
self
.
domain
)):
ret_dict
[
's_'
+
str
(
i
)]
=
self
.
domain
[
i
]
...
...
@@ -854,7 +859,12 @@ class Field(Loggable, Versionable, object):
new_field
.
domain
=
tuple
(
temp_domain
)
exec
(
'new_field.domain_axes = '
+
hdf5_group
.
attrs
[
'domain_axes'
])
new_field
.
_val
=
repository
.
get
(
'val'
,
hdf5_group
)
try
:
new_field
.
_val
=
repository
.
get
(
'val'
,
hdf5_group
)
except
(
KeyError
):
new_field
.
_val
=
None
new_field
.
dtype
=
np
.
dtype
(
hdf5_group
.
attrs
[
'dtype'
])
new_field
.
distribution_strategy
=
\
hdf5_group
.
attrs
[
'distribution_strategy'
]
...
...
nifty/field_types/field_array.py
View file @
c31c7198
# -*- coding: utf-8 -*-
import
pickle
import
numpy
as
np
from
field_type
import
FieldType
...
...
@@ -29,14 +27,14 @@ class FieldArray(FieldType):
def
_to_hdf5
(
self
,
hdf5_group
):
hdf5_group
[
'shape'
]
=
self
.
shape
hdf5_group
[
'dtype'
]
=
pickle
.
dumps
(
self
.
dtype
)
hdf5_group
.
attrs
[
'dtype'
]
=
self
.
dtype
.
name
return
None
@
classmethod
def
_from_hdf5
(
cls
,
hdf5_group
,
loopback_get
):
result
=
cls
(
hdf5_group
[
'shape'
][:],
pickle
.
loads
(
hdf5_group
[
'dtype'
]
[()]
)
shape
=
hdf5_group
[
'shape'
][:],
dtype
=
np
.
dtype
(
hdf5_group
.
attrs
[
'dtype'
])
)
return
result
nifty/field_types/field_type.py
View file @
c31c7198
...
...
@@ -5,6 +5,13 @@ from nifty.domain_object import DomainObject
class
FieldType
(
DomainObject
):
def
weight
(
self
,
x
,
power
=
1
,
axes
=
None
,
inplace
=
False
):
if
inplace
:
result
=
x
else
:
result
=
x
.
copy
()
return
result
def
process
(
self
,
method_name
,
array
,
inplace
=
True
,
**
kwargs
):
try
:
result_array
=
self
.
__getattr__
(
method_name
)(
array
,
...
...
nifty/operators/__init__.py
View file @
c31c7198
...
...
@@ -33,6 +33,8 @@ from fft_operator import *
from
invertible_operator_mixin
import
InvertibleOperatorMixin
from
projection_operator
import
ProjectionOperator
from
propagator_operator
import
PropagatorOperator
from
composed_operator
import
ComposedOperator
nifty/operators/diagonal_operator/diagonal_operator.py
View file @
c31c7198
...
...
@@ -90,10 +90,15 @@ class DiagonalOperator(EndomorphicOperator):
@
property
def
symmetric
(
self
):
if
self
.
_symmetric
is
None
:
self
.
_symmetric
=
(
self
.
_diagonal
.
val
.
imag
==
0
).
all
()
return
self
.
_symmetric
@
property
def
unitary
(
self
):
if
self
.
_unitary
is
None
:
self
.
_unitary
=
(
self
.
_diagonal
.
val
*
self
.
_diagonal
.
val
.
conjugate
()
==
1
).
all
()
return
self
.
_unitary
# ---Added properties and methods---
...
...
@@ -134,11 +139,11 @@ class DiagonalOperator(EndomorphicOperator):
# Otherwise, inplace weightening would change the external field
f
.
weight
(
inplace
=
copy
,
power
=-
1
)
#
check if the operator is symmetric
:
self
.
_symmetric
=
(
f
.
val
.
imag
==
0
).
all
()
#
Reset the symmetric property
:
self
.
_symmetric
=
None
#
check if the operator is unitary:
self
.
_unitary
=
(
f
.
val
*
f
.
val
.
conjugate
()
==
1
).
all
()
#
Reset the unitarity property
self
.
_unitary
=
None
# store the diagonal-field
self
.
_diagonal
=
f
...
...
@@ -154,9 +159,7 @@ class DiagonalOperator(EndomorphicOperator):
# the one of x, reshape the local data of self and apply it directly
active_axes
=
[]
if
spaces
is
None
:
if
self
.
domain
!=
():
for
axes
in
x
.
domain_axes
:
active_axes
+=
axes
active_axes
=
range
(
len
(
x
.
shape
))
else
:
for
space_index
in
spaces
:
active_axes
+=
x
.
domain_axes
[
space_index
]
...
...
@@ -167,6 +170,8 @@ class DiagonalOperator(EndomorphicOperator):
local_diagonal
=
self
.
_diagonal
.
val
.
get_local_data
(
copy
=
False
)
else
:
# create an array that is sub-slice compatible
self
.
logger
.
warn
(
"The input field is not sub-slice compatible to "
"the distribution strategy of the operator."
)
redistr_diagonal_val
=
self
.
_diagonal
.
val
.
copy
(
distribution_strategy
=
axes_local_distribution_strategy
)
local_diagonal
=
redistr_diagonal_val
.
get_local_data
(
copy
=
False
)
...
...
nifty/operators/fft_operator/fft_operator.py
View file @
c31c7198
import
numpy
as
np
import
nifty.nifty_utilities
as
utilities
from
nifty.spaces
import
RGSpace
,
\
GLSpace
,
\
...
...
@@ -33,7 +35,8 @@ class FFTOperator(LinearOperator):
# ---Overwritten properties and methods---
def
__init__
(
self
,
domain
=
(),
target
=
None
,
module
=
None
):
def
__init__
(
self
,
domain
=
(),
target
=
None
,
module
=
None
,
domain_dtype
=
None
,
target_dtype
=
None
):
self
.
_domain
=
self
.
_parse_domain
(
domain
)
...
...
@@ -69,7 +72,18 @@ class FFTOperator(LinearOperator):
self
.
_backward_transformation
=
TransformationCache
.
create
(
backward_class
,
self
.
target
[
0
],
self
.
domain
[
0
],
module
=
module
)
def
_times
(
self
,
x
,
spaces
,
dtype
=
None
):
# Store the dtype information
if
domain_dtype
is
None
:
self
.
domain_dtype
=
None
else
:
self
.
domain_dtype
=
np
.
dtype
(
domain_dtype
)
if
target_dtype
is
None
:
self
.
target_dtype
=
None
else
:
self
.
target_dtype
=
np
.
dtype
(
target_dtype
)
def
_times
(
self
,
x
,
spaces
):
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
))
if
spaces
is
None
:
# this case means that x lives on only one space, which is
...
...
@@ -87,12 +101,13 @@ class FFTOperator(LinearOperator):
result_domain
=
list
(
x
.
domain
)
result_domain
[
spaces
[
0
]]
=
self
.
target
[
0
]
result_field
=
x
.
copy_empty
(
domain
=
result_domain
,
dtype
=
dtype
)
result_field
=
x
.
copy_empty
(
domain
=
result_domain
,
dtype
=
self
.
target_dtype
)
result_field
.
set_val
(
new_val
=
new_val
,
copy
=
False
)
return
result_field
def
_inverse_times
(
self
,
x
,
spaces
,
dtype
=
None
):
def
_inverse_times
(
self
,
x
,
spaces
):
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
))
if
spaces
is
None
:
# this case means that x lives on only one space, which is
...
...
@@ -110,7 +125,8 @@ class FFTOperator(LinearOperator):
result_domain
=
list
(
x
.
domain
)
result_domain
[
spaces
[
0
]]
=
self
.
domain
[
0
]
result_field
=
x
.
copy_empty
(
domain
=
result_domain
,
dtype
=
dtype
)
result_field
=
x
.
copy_empty
(
domain
=
result_domain
,
dtype
=
self
.
domain_dtype
)
result_field
.
set_val
(
new_val
=
new_val
,
copy
=
False
)
return
result_field
...
...
nifty/operators/fft_operator/transformations/rg_transforms.py
View file @
c31c7198
...
...
@@ -592,6 +592,7 @@ class GFFT(Transform):
# Array for storing the result
return_val
=
None
result_dtype
=
np
.
result_type
(
np
.
complex
,
self
.
codomain
.
dtype
)
for
slice_list
in
utilities
.
get_slice_list
(
temp_inp
.
shape
,
axes
):
...
...
@@ -601,7 +602,7 @@ class GFFT(Transform):
else
:
# initialize the return_val object if needed
if
return_val
is
None
:
return_val
=
np
.
empty_like
(
temp_inp
)
return_val
=
np
.
empty_like
(
temp_inp
,
dtype
=
result_dtype
)
inp
=
temp_inp
[
slice_list
]
inp
=
self
.
fft_machine
.
gfft
(
...
...
@@ -622,12 +623,11 @@ class GFFT(Transform):
else
:
return_val
[
slice_list
]
=
inp
result_dtype
=
np
.
result_type
(
np
.
complex
,
self
.
codomain
.
dtype
)
if
isinstance
(
val
,
distributed_data_object
):
new_val
=
val
.
copy_empty
(
dtype
=
result_dtype
)
new_val
.
set_full_data
(
return_val
,
copy
=
False
)
return_val
=
new_val
else
:
return_val
=
return_val
.
astype
(
result_type
,
copy
=
False
)
return_val
=
return_val
.
astype
(
result_
d
type
,
copy
=
False
)
return
return_val
nifty/operators/linear_operator/linear_operator.py
View file @
c31c7198
...
...
@@ -127,7 +127,7 @@ class LinearOperator(Loggable, object):
self_domain
=
self
.
target
if
spaces
is
None
:
if
self_domain
!=
()
and
self_domain
!=
x
.
domain
:
if
self_domain
!=
x
.
domain
:
raise
ValueError
(
"The operator's and and field's domains don't "
"match."
)
...
...
nifty/operators/projection_operator/__init__.py
0 → 100644
View file @
c31c7198
# -*- coding: utf-8 -*-
from
projection_operator
import
ProjectionOperator
nifty/operators/projection_operator/projection_operator.py
0 → 100644
View file @
c31c7198
# -*- coding: utf-8 -*-
import
numpy
as
np
from
nifty.field
import
Field
from
nifty.operators.endomorphic_operator
import
EndomorphicOperator
class
ProjectionOperator
(
EndomorphicOperator
):
# ---Overwritten properties and methods---
def
__init__
(
self
,
projection_field
):
if
not
isinstance
(
projection_field
,
Field
):
raise
TypeError
(
"The projection_field must be a NIFTy-Field"
"instance."
)
self
.
_projection_field
=
projection_field
self
.
_unitary
=
None
def
_times
(
self
,
x
,
spaces
):
# if the domain matches directly
# -> multiply the fields directly
if
x
.
domain
==
self
.
domain
:
# here the actual multiplication takes place
dotted
=
(
self
.
_projection_field
*
x
).
sum
()
return
self
.
_projection_field
*
dotted
# if the distribution_strategy of self is sub-slice compatible to
# the one of x, reshape the local data of self and apply it directly
active_axes
=
[]
if
spaces
is
None
:
active_axes
=
range
(
len
(
x
.
shape
))
else
:
for
space_index
in
spaces
:
active_axes
+=
x
.
domain_axes
[
space_index
]
axes_local_distribution_strategy
=
\
x
.
val
.
get_axes_local_distribution_strategy
(
active_axes
)
if
axes_local_distribution_strategy
==
\
self
.
_projection_field
.
distribution_strategy
:
local_projection_vector
=
\
self
.
_projection_field
.
val
.
get_local_data
(
copy
=
False
)
else
:
# create an array that is sub-slice compatible
self
.
logger
.
warn
(
"The input field is not sub-slice compatible to "
"the distribution strategy of the operator. "
"Performing an probably expensive "
"redistribution."
)
redistr_projection_val
=
self
.
_projection_field
.
val
.
copy
(
distribution_strategy
=
axes_local_distribution_strategy
)
local_projection_vector
=
\
redistr_projection_val
.
get_local_data
(
copy
=
False
)
local_x
=
x
.
val
.
get_local_data
(
copy
=
False
)
l
=
len
(
local_projection_vector
.
shape
)
sublist_projector
=
range
(
l
)
sublist_x
=
np
.
arange
(
len
(
local_x
.
shape
))
+
l
for
i
in
xrange
(
l
):
a
=
active_axes
[
i
]
sublist_x
[
a
]
=
i
dotted
=
np
.
einsum
(
local_projection_vector
,
sublist_projector
,
local_x
,
sublist_x
)
# get those elements from sublist_x that haven't got contracted
sublist_dotted
=
sublist_x
[
sublist_x
>=
l
]
remultiplied
=
np
.
einsum
(
local_projection_vector
,
sublist_projector
,
dotted
,
sublist_dotted
,
sublist_x
)
result_field
=
x
.
copy_empty
(
dtype
=
remultiplied
.
dtype
)
result_field
.
val
.
set_local_data
(
remultiplied
,
copy
=
False
)
return
result_field
def
_inverse_times
(
self
,
x
,
spaces
):
raise
NotImplementedError
(
"The ProjectionOperator is a singular "
"operator and therefore has no inverse."
)
# ---Mandatory properties and methods---
@
property
def
domain
(
self
):
return
self
.
_projection_field
.
domain
@
property
def
implemented
(
self
):
return
True
@
property
def
unitary
(
self
):
if
self
.
_unitary
is
None
:
self
.
_unitary
=
(
self
.
_projection_field
.
val
==
1
).
all
()
return
self
.
_unitary
@
property
def
symmetric
(
self
):
return
True
nifty/spaces/gl_space/gl_space.py
View file @
c31c7198
...
...
@@ -131,10 +131,9 @@ class GLSpace(Space):
def
weight
(
self
,
x
,
power
=
1
,
axes
=
None
,
inplace
=
False
):
nlon
=
self
.
nlon
nlat
=
self
.
nlat
vol
=
gl
.
vol
(
nlat
)
**
power
weight
=
np
.
array
(
list
(
itertools
.
chain
.
from_iterable
(
itertools
.
repeat
(
x
**
power
,
nlon
)
for
x
in
gl
.
vol
(
nlat
))))
itertools
.
repeat
(
x
,
nlon
)
for
x
in
vol
)))
if
axes
is
not
None
:
# reshape the weight array to match the input shape
...
...
@@ -209,7 +208,7 @@ class GLSpace(Space):
def
_to_hdf5
(
self
,
hdf5_group
):
hdf5_group
[
'nlat'
]
=
self
.
nlat
hdf5_group
[
'nlon'
]
=
self
.
nlon
hdf5_group
[
'dtype'
]
=
self
.
dtype
.
name
hdf5_group
.
attrs
[
'dtype'
]
=
self
.
dtype
.
name
return
None
...
...
@@ -218,7 +217,7 @@ class GLSpace(Space):
result
=
cls
(
nlat
=
hdf5_group
[
'nlat'
][()],
nlon
=
hdf5_group
[
'nlon'
][()],
dtype
=
np
.
dtype
(
hdf5_group
[
'dtype'
]
[()]
)
dtype
=
np
.
dtype
(
hdf5_group
.
attrs
[
'dtype'
])
)
return
result
nifty/spaces/hp_space/hp_space.py
View file @
c31c7198
...
...
@@ -209,13 +209,13 @@ class HPSpace(Space):
def
_to_hdf5
(
self
,
hdf5_group
):
hdf5_group
[
'nside'
]
=
self
.
nside
hdf5_group
[
'dtype'
]
=
self
.
dtype
.
name
hdf5_group
.
attrs
[
'dtype'
]
=
self
.
dtype
.
name
return
None
@
classmethod
def
_from_hdf5
(
cls
,
hdf5_group
,
repository
):
result
=
cls
(
nside
=
hdf5_group
[
'nside'
][()],
dtype
=
np
.
dtype
(
hdf5_group
[
'dtype'
]
[()]
)
dtype
=
np
.
dtype
(
hdf5_group
.
attrs
[
'dtype'
])
)
return
result
nifty/spaces/lm_space/lm_space.py
View file @
c31c7198
...
...
@@ -111,7 +111,8 @@ class LMSpace(Space):
super
(
LMSpace
,
self
).
__init__
(
dtype
)
self
.
_lmax
=
self
.
_parse_lmax
(
lmax
)
def
hermitian_decomposition
(
self
,
x
,
axes
=
None
):
def
hermitian_decomposition
(
self
,
x
,
axes
=
None
,
preserve_gaussian_variance
=
False
):
hermitian_part
=
x
.
copy_empty
()
anti_hermitian_part
=
x
.
copy_empty
()
hermitian_part
[:]
=
x
.
real
...
...
@@ -145,7 +146,6 @@ class LMSpace(Space):
def
copy
(
self
):
return
self
.
__class__
(
lmax
=
self
.
lmax
,
mmax
=
self
.
mmax
,
dtype
=
self
.
dtype
)
def
weight
(
self
,
x
,
power
=
1
,
axes
=
None
,
inplace
=
False
):
...
...
@@ -193,13 +193,13 @@ class LMSpace(Space):
def
_to_hdf5
(
self
,
hdf5_group
):
hdf5_group
[
'lmax'
]
=
self
.
lmax
hdf5_group
[
'dtype'
]
=
self
.
dtype
.
name
hdf5_group
.
attrs
[
'dtype'
]
=
self
.
dtype
.
name
return
None
@
classmethod
def
_from_hdf5
(
cls
,
hdf5_group
,
repository
):
result
=
cls
(
lmax
=
hdf5_group
[
'lmax'
][()],
dtype
=
np
.
dtype
(
hdf5_group
[
'dtype'
]
[()]
)
dtype
=
np
.
dtype
(
hdf5_group
.
attrs
[
'dtype'
])
)
return
result
nifty/spaces/power_space/power_space.py
View file @
c31c7198
...
...
@@ -73,7 +73,7 @@ class PowerSpace(Space):
@
property
def
total_volume
(
self
):
# every power-pixel has a volume of 1
return
reduce
(
lambda
x
,
y
:
x
*
y
,
self
.
pindex
.
shape
)
return
float
(
reduce
(
lambda
x
,
y
:
x
*
y
,
self
.
pindex
.
shape
)
)
def
copy
(
self
):
distribution_strategy
=
self
.
pindex
.
distribution_strategy
...
...
@@ -85,14 +85,8 @@ class PowerSpace(Space):
dtype
=
self
.
dtype
)
def
weight
(
self
,
x
,
power
=
1
,
axes
=
None
,
inplace
=
False
):
total_shape
=
x
.
shape
axes
=
cast_axis_to_tuple
(
axes
,
len
(
total_shape
))
if
len
(
axes
)
!=
1
:
raise
ValueError
(
"axes must be of length 1."
)
reshaper
=
[
1
,
]
*
len
(
total_shape
)
reshaper
=
[
1
,
]
*
len
(
x
.
shape
)
# we know len(axes) is always 1
reshaper
[
axes
[
0
]]
=
self
.
shape
[
0
]
weight
=
self
.
rho
.
reshape
(
reshaper
)
...
...
@@ -179,8 +173,9 @@ class PowerSpace(Space):
new_ps
=
EmptyPowerSpace
()
# reset class
new_ps
.
__class__
=
cls
# call instructor so that classes are properly setup
super
(
PowerSpace
,
new_ps
).
__init__
(
np
.
dtype
(
hdf5_group
.
attrs
[
'dtype'
]))
# set all values
new_ps
.
dtype
=
np
.
dtype
(
hdf5_group
.
attrs
[
'dtype'
])
new_ps
.
_harmonic_domain
=
repository
.
get
(
'harmonic_domain'
,
hdf5_group
)
new_ps
.
_log
=
hdf5_group
[
'log'
][()]
exec
(
'new_ps._nbin = '
+
hdf5_group
.
attrs
[
'nbin'
])
...
...
@@ -191,6 +186,8 @@ class PowerSpace(Space):
new_ps
.
_rho
=
hdf5_group
[
'rho'
][:]
new_ps
.
_pundex
=
hdf5_group
[
'pundex'
][:]
new_ps
.
_k_array
=
repository
.
get
(
'k_array'
,
hdf5_group
)
new_ps
.
_ignore_for_hash
+=
[
'_pindex'
,
'_kindex'
,
'_rho'
,
'_pundex'
,
'_k_array'
]
return
new_ps
...
...
nifty/spaces/rg_space/rg_space.py
View file @
c31c7198
...
...
@@ -150,7 +150,8 @@ class RGSpace(Space):
self
.
_distances
=
self
.
_parse_distances
(
distances
)
self
.
_zerocenter
=
self
.
_parse_zerocenter
(
zerocenter
)
def
hermitian_decomposition
(
self
,
x
,
axes
=
None
):
def
hermitian_decomposition
(
self
,
x
,
axes
=
None
,
preserve_gaussian_variance
=
False
):
# compute the hermitian part
flipped_x
=
self
.
_hermitianize_inverter
(
x
,
axes
=
axes
)
flipped_x
=
flipped_x
.
conjugate
()
...
...
@@ -160,8 +161,39 @@ class RGSpace(Space):
# use subtraction since it is faster than flipping another time
anti_hermitian_part
=
(
x
-
hermitian_part
)
/
1j