Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
3781a34e
Commit
3781a34e
authored
Sep 17, 2017
by
Martin Reinecke
Browse files
simplify power_analyze()
parent
482d8db2
Pipeline
#18319
passed with stage
in 6 minutes and 5 seconds
Changes
5
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty2go/field.py
View file @
3781a34e
...
...
@@ -175,7 +175,7 @@ class Field(object):
Parameters
----------
spaces : int *optional*
The subspace for which the powerspectrum shall be computed
The subspace for which the powerspectrum shall be computed
.
(default : None).
binbounds : array-like *optional*
Inner bounds of the bins (default : None).
...
...
@@ -193,11 +193,8 @@ class Field(object):
Raise
-----
ValueError
Raised if
*len(domain) is != 1 when spaces==None
*len(spaces) is != 1 if not None
*the analyzed space is not harmonic
TypeError
Raised if any of the input field's domains is not harmonic
Returns
-------
...
...
@@ -219,9 +216,10 @@ class Field(object):
"neither harmonic nor a PowerSpace."
)
# check if the `spaces` input is valid
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
self
.
domain
))
if
spaces
is
None
:
spaces
=
range
(
len
(
self
.
domain
))
else
:
spaces
=
utilities
.
cast_iseq_to_tuple
(
spaces
)
if
len
(
spaces
)
==
0
:
raise
ValueError
(
"No space for analysis specified."
)
...
...
@@ -232,70 +230,37 @@ class Field(object):
parts
=
[
self
.
real
*
self
.
real
+
self
.
imag
*
self
.
imag
]
for
space_index
in
spaces
:
parts
=
[
self
.
_single_power_analyze
(
work_
field
=
part
,
space_inde
x
=
space_index
,
parts
=
[
self
.
_single_power_analyze
(
field
=
part
,
id
x
=
space_index
,
binbounds
=
binbounds
)
for
part
in
parts
]
return
parts
[
0
]
+
1j
*
parts
[
1
]
if
keep_phase_information
else
parts
[
0
]
@
staticmethod
def
_single_power_analyze
(
work_field
,
space_index
,
binbounds
):
if
not
work_field
.
domain
[
space_index
].
harmonic
:
raise
ValueError
(
"The analyzed space must be harmonic."
)
# Create the target PowerSpace instance:
# If the associated signal-space field was real, we extract the
# hermitian and anti-hermitian parts of `self` and put them
# into the real and imaginary parts of the power spectrum.
# If it was complex, all the power is put into a real power spectrum.
harmonic_domain
=
work_field
.
domain
[
space_index
]
power_domain
=
PowerSpace
(
harmonic_partner
=
harmonic_domain
,
binbounds
=
binbounds
)
power_spectrum
=
Field
.
_calculate_power_spectrum
(
field_val
=
work_field
.
val
,
pdomain
=
power_domain
,
axes
=
work_field
.
domain_axes
[
space_index
])
# create the result field and put power_spectrum into it
result_domain
=
list
(
work_field
.
domain
)
result_domain
[
space_index
]
=
power_domain
return
Field
(
domain
=
result_domain
,
val
=
power_spectrum
,
dtype
=
power_spectrum
.
dtype
)
@
staticmethod
def
_calculate_power_spectrum
(
field_val
,
pdomain
,
axes
=
None
):
pindex
=
pdomain
.
pindex
if
axes
is
not
None
:
pindex
=
Field
.
_shape_up_pindex
(
pindex
,
field_val
.
shape
,
axes
)
def
_single_power_analyze
(
field
,
idx
,
binbounds
):
power_domain
=
PowerSpace
(
field
.
domain
[
idx
],
binbounds
)
pindex
=
power_domain
.
pindex
axes
=
field
.
domain_axes
[
idx
]
new_pindex_shape
=
[
1
]
*
len
(
field
.
shape
)
for
i
,
ax
in
enumerate
(
axes
):
new_pindex_shape
[
ax
]
=
pindex
.
shape
[
i
]
pindex
=
np
.
broadcast_to
(
pindex
.
reshape
(
new_pindex_shape
),
field
.
shape
)
power_spectrum
=
utilities
.
bincount_axis
(
pindex
,
weights
=
field
_
val
,
power_spectrum
=
utilities
.
bincount_axis
(
pindex
,
weights
=
field
.
val
,
axis
=
axes
)
rho
=
pdomain
.
rho
if
axes
is
not
None
:
new_rho_shape
=
[
1
]
*
len
(
power_spectrum
.
shape
)
new_rho_shape
[
axes
[
0
]]
=
len
(
rho
)
rho
=
rho
.
reshape
(
new_rho_shape
)
power_spectrum
/=
rho
return
power_spectrum
@
staticmethod
def
_shape_up_pindex
(
pindex
,
target_shape
,
axes
):
semiscaled_local_shape
=
[
1
]
*
len
(
target_shape
)
for
i
,
ax
in
enumerate
(
axes
):
semiscaled_local_shape
[
ax
]
=
pindex
.
shape
[
i
]
result_obj
=
np
.
empty
(
target_shape
,
dtype
=
pindex
.
dtype
)
result_obj
[()]
=
pindex
.
reshape
(
semiscaled_local_shape
)
return
result_obj
new_rho_shape
=
[
1
]
*
len
(
power_spectrum
.
shape
)
new_rho_shape
[
axes
[
0
]]
=
len
(
power_domain
.
rho
)
power_spectrum
/=
power_domain
.
rho
.
reshape
(
new_rho_shape
)
result_domain
=
list
(
field
.
domain
)
result_domain
[
idx
]
=
power_domain
return
Field
(
result_domain
,
power_spectrum
)
def
_compute_spec
(
self
,
spaces
):
# check if the `spaces` input is valid
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
self
.
domain
))
if
spaces
is
None
:
spaces
=
range
(
len
(
self
.
domain
))
else
:
spaces
=
utilities
.
cast_iseq_to_tuple
(
spaces
)
# create the result domain
result_domain
=
list
(
self
.
domain
)
...
...
@@ -503,9 +468,10 @@ class Field(object):
"""
new_field
=
Field
(
val
=
self
,
copy
=
not
inplace
)
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
self
.
domain
))
if
spaces
is
None
:
spaces
=
range
(
len
(
self
.
domain
))
else
:
spaces
=
utilities
.
cast_iseq_to_tuple
(
spaces
)
fct
=
1.
for
ind
in
spaces
:
...
...
@@ -606,8 +572,8 @@ class Field(object):
def
_contraction_helper
(
self
,
op
,
spaces
):
if
spaces
is
None
:
return
getattr
(
self
.
val
,
op
)()
# build a list of all axes
spaces
=
utilities
.
cast_
ax
is_to_tuple
(
spaces
,
len
(
self
.
domain
)
)
else
:
spaces
=
utilities
.
cast_is
eq
_to_tuple
(
spaces
)
axes_list
=
tuple
(
self
.
domain_axes
[
sp_index
]
for
sp_index
in
spaces
)
...
...
nifty2go/nifty_utilities.py
View file @
3781a34e
...
...
@@ -68,30 +68,12 @@ def get_slice_list(shape, axes):
yield
[
slice
(
None
,
None
)]
def
cast_
ax
is_to_tuple
(
axis
,
length
=
None
):
if
axis
is
None
:
def
cast_is
eq
_to_tuple
(
seq
):
if
seq
is
None
:
return
None
try
:
axis
=
tuple
(
int
(
item
)
for
item
in
axis
)
except
(
TypeError
):
if
np
.
isscalar
(
axis
):
axis
=
(
int
(
axis
),)
else
:
raise
TypeError
(
"Could not convert axis-input to tuple of ints"
)
if
length
is
not
None
:
# shift negative indices to positive ones
axis
=
tuple
(
item
if
(
item
>=
0
)
else
(
item
+
length
)
for
item
in
axis
)
# Deactivated this, in order to allow for the ComposedOperator
# remove duplicate entries
# axis = tuple(set(axis))
# assert that all entries are elements in [0, length]
for
elem
in
axis
:
assert
(
0
<=
elem
<
length
)
return
axis
if
np
.
isscalar
(
seq
):
return
(
int
(
seq
),)
return
tuple
(
int
(
item
)
for
item
in
seq
)
def
parse_domain
(
domain
):
...
...
@@ -135,7 +117,7 @@ def bincount_axis(obj, minlength=None, weights=None, axis=None):
if
axis
is
not
None
:
# do the reordering
ndim
=
len
(
obj
.
shape
)
axis
=
sorted
(
cast_
ax
is_to_tuple
(
axis
,
length
=
ndim
))
axis
=
sorted
(
cast_is
eq
_to_tuple
(
axis
))
reordering
=
[
x
for
x
in
range
(
ndim
)
if
x
not
in
axis
]
reordering
+=
axis
...
...
nifty2go/operators/fft_operator/fft_operator.py
View file @
3781a34e
...
...
@@ -130,7 +130,7 @@ class FFTOperator(LinearOperator):
axes
=
x
.
domain_axes
[
0
]
result_domain
=
other
else
:
spaces
=
utilities
.
cast_
ax
is_to_tuple
(
spaces
,
len
(
x
.
domain
)
)
spaces
=
utilities
.
cast_is
eq
_to_tuple
(
spaces
)
result_domain
=
list
(
x
.
domain
)
result_domain
[
spaces
[
0
]]
=
other
[
0
]
axes
=
x
.
domain_axes
[
spaces
[
0
]]
...
...
nifty2go/operators/laplace_operator/laplace_operator.py
View file @
3781a34e
...
...
@@ -89,13 +89,13 @@ class LaplaceOperator(EndomorphicOperator):
return
self
.
_logarithmic
def
_times
(
self
,
x
,
spaces
):
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
))
if
spaces
is
None
:
# this case means that x lives on only one space, which is
# identical to the space in the domain of `self`. Otherwise the
# input check of LinearOperator would have failed.
axes
=
x
.
domain_axes
[
0
]
else
:
spaces
=
utilities
.
cast_iseq_to_tuple
(
spaces
)
axes
=
x
.
domain_axes
[
spaces
[
0
]]
axis
=
axes
[
0
]
nval
=
len
(
self
.
_dposc
)
...
...
@@ -115,13 +115,13 @@ class LaplaceOperator(EndomorphicOperator):
return
Field
(
self
.
domain
,
val
=
ret
).
weight
(
power
=-
0.5
,
spaces
=
spaces
)
def
_adjoint_times
(
self
,
x
,
spaces
):
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
))
if
spaces
is
None
:
# this case means that x lives on only one space, which is
# identical to the space in the domain of `self`. Otherwise the
# input check of LinearOperator would have failed.
axes
=
x
.
domain_axes
[
0
]
else
:
spaces
=
utilities
.
cast_iseq_to_tuple
(
spaces
)
axes
=
x
.
domain_axes
[
spaces
[
0
]]
axis
=
axes
[
0
]
nval
=
len
(
self
.
_dposc
)
...
...
nifty2go/operators/linear_operator/linear_operator.py
View file @
3781a34e
...
...
@@ -266,8 +266,9 @@ class LinearOperator(with_metaclass(
else
:
spaces
=
self
.
default_spaces
[::
-
1
]
# sanitize the `spaces` and `types` input
spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
,
len
(
x
.
domain
))
# sanitize the `spaces` input
if
spaces
is
not
None
:
spaces
=
utilities
.
cast_iseq_to_tuple
(
spaces
)
# if the operator's domain is set to something, there are two valid
# cases:
...
...
@@ -281,9 +282,8 @@ class LinearOperator(with_metaclass(
if
spaces
is
None
:
if
self_domain
!=
x
.
domain
:
raise
ValueError
(
"The operator's and and field's domains don't "
"match."
)
raise
ValueError
(
"The operator's and and field's domains "
"don't match."
)
else
:
for
i
,
space_index
in
enumerate
(
spaces
):
if
x
.
domain
[
space_index
]
!=
self_domain
[
i
]:
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment