Commit 3781a34e authored by Martin Reinecke's avatar Martin Reinecke
Browse files

simplify power_analyze()

parent 482d8db2
Pipeline #18319 passed with stage
in 6 minutes and 5 seconds
......@@ -175,7 +175,7 @@ class Field(object):
Parameters
----------
spaces : int *optional*
The subspace for which the powerspectrum shall be computed
The subspace for which the powerspectrum shall be computed.
(default : None).
binbounds : array-like *optional*
Inner bounds of the bins (default : None).
......@@ -193,11 +193,8 @@ class Field(object):
Raise
-----
ValueError
Raised if
*len(domain) is != 1 when spaces==None
*len(spaces) is != 1 if not None
*the analyzed space is not harmonic
TypeError
Raised if any of the input field's domains is not harmonic
Returns
-------
......@@ -219,9 +216,10 @@ class Field(object):
"neither harmonic nor a PowerSpace.")
# check if the `spaces` input is valid
spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
if spaces is None:
spaces = range(len(self.domain))
else:
spaces = utilities.cast_iseq_to_tuple(spaces)
if len(spaces) == 0:
raise ValueError("No space for analysis specified.")
......@@ -232,70 +230,37 @@ class Field(object):
parts = [self.real*self.real + self.imag*self.imag]
for space_index in spaces:
parts = [self._single_power_analyze(work_field=part,
space_index=space_index,
parts = [self._single_power_analyze(field=part,
idx=space_index,
binbounds=binbounds)
for part in parts]
return parts[0] + 1j*parts[1] if keep_phase_information else parts[0]
@staticmethod
def _single_power_analyze(work_field, space_index, binbounds):
if not work_field.domain[space_index].harmonic:
raise ValueError("The analyzed space must be harmonic.")
# Create the target PowerSpace instance:
# If the associated signal-space field was real, we extract the
# hermitian and anti-hermitian parts of `self` and put them
# into the real and imaginary parts of the power spectrum.
# If it was complex, all the power is put into a real power spectrum.
harmonic_domain = work_field.domain[space_index]
power_domain = PowerSpace(harmonic_partner=harmonic_domain,
binbounds=binbounds)
power_spectrum = Field._calculate_power_spectrum(
field_val=work_field.val,
pdomain=power_domain,
axes=work_field.domain_axes[space_index])
# create the result field and put power_spectrum into it
result_domain = list(work_field.domain)
result_domain[space_index] = power_domain
return Field(domain=result_domain, val=power_spectrum,
dtype=power_spectrum.dtype)
@staticmethod
def _calculate_power_spectrum(field_val, pdomain, axes=None):
pindex = pdomain.pindex
if axes is not None:
pindex = Field._shape_up_pindex(pindex, field_val.shape, axes)
def _single_power_analyze(field, idx, binbounds):
power_domain = PowerSpace(field.domain[idx], binbounds)
pindex = power_domain.pindex
axes = field.domain_axes[idx]
new_pindex_shape = [1] * len(field.shape)
for i, ax in enumerate(axes):
new_pindex_shape[ax] = pindex.shape[i]
pindex = np.broadcast_to(pindex.reshape(new_pindex_shape), field.shape)
power_spectrum = utilities.bincount_axis(pindex, weights=field_val,
power_spectrum = utilities.bincount_axis(pindex, weights=field.val,
axis=axes)
rho = pdomain.rho
if axes is not None:
new_rho_shape = [1] * len(power_spectrum.shape)
new_rho_shape[axes[0]] = len(rho)
rho = rho.reshape(new_rho_shape)
power_spectrum /= rho
return power_spectrum
@staticmethod
def _shape_up_pindex(pindex, target_shape, axes):
semiscaled_local_shape = [1] * len(target_shape)
for i, ax in enumerate(axes):
semiscaled_local_shape[ax] = pindex.shape[i]
result_obj = np.empty(target_shape, dtype=pindex.dtype)
result_obj[()] = pindex.reshape(semiscaled_local_shape)
return result_obj
new_rho_shape[axes[0]] = len(power_domain.rho)
power_spectrum /= power_domain.rho.reshape(new_rho_shape)
result_domain = list(field.domain)
result_domain[idx] = power_domain
return Field(result_domain, power_spectrum)
def _compute_spec(self, spaces):
# check if the `spaces` input is valid
spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
if spaces is None:
spaces = range(len(self.domain))
else:
spaces = utilities.cast_iseq_to_tuple(spaces)
# create the result domain
result_domain = list(self.domain)
......@@ -503,9 +468,10 @@ class Field(object):
"""
new_field = Field(val=self, copy=not inplace)
spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
if spaces is None:
spaces = range(len(self.domain))
else:
spaces = utilities.cast_iseq_to_tuple(spaces)
fct = 1.
for ind in spaces:
......@@ -606,8 +572,8 @@ class Field(object):
def _contraction_helper(self, op, spaces):
if spaces is None:
return getattr(self.val, op)()
# build a list of all axes
spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
else:
spaces = utilities.cast_iseq_to_tuple(spaces)
axes_list = tuple(self.domain_axes[sp_index] for sp_index in spaces)
......
......@@ -68,30 +68,12 @@ def get_slice_list(shape, axes):
yield [slice(None, None)]
def cast_axis_to_tuple(axis, length=None):
if axis is None:
def cast_iseq_to_tuple(seq):
if seq is None:
return None
try:
axis = tuple(int(item) for item in axis)
except(TypeError):
if np.isscalar(axis):
axis = (int(axis),)
else:
raise TypeError("Could not convert axis-input to tuple of ints")
if length is not None:
# shift negative indices to positive ones
axis = tuple(item if (item >= 0) else (item + length) for item in axis)
# Deactivated this, in order to allow for the ComposedOperator
# remove duplicate entries
# axis = tuple(set(axis))
# assert that all entries are elements in [0, length]
for elem in axis:
assert (0 <= elem < length)
return axis
if np.isscalar(seq):
return (int(seq),)
return tuple(int(item) for item in seq)
def parse_domain(domain):
......@@ -135,7 +117,7 @@ def bincount_axis(obj, minlength=None, weights=None, axis=None):
if axis is not None:
# do the reordering
ndim = len(obj.shape)
axis = sorted(cast_axis_to_tuple(axis, length=ndim))
axis = sorted(cast_iseq_to_tuple(axis))
reordering = [x for x in range(ndim) if x not in axis]
reordering += axis
......
......@@ -130,7 +130,7 @@ class FFTOperator(LinearOperator):
axes = x.domain_axes[0]
result_domain = other
else:
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
spaces = utilities.cast_iseq_to_tuple(spaces)
result_domain = list(x.domain)
result_domain[spaces[0]] = other[0]
axes = x.domain_axes[spaces[0]]
......
......@@ -89,13 +89,13 @@ class LaplaceOperator(EndomorphicOperator):
return self._logarithmic
def _times(self, x, spaces):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
if spaces is None:
# this case means that x lives on only one space, which is
# identical to the space in the domain of `self`. Otherwise the
# input check of LinearOperator would have failed.
axes = x.domain_axes[0]
else:
spaces = utilities.cast_iseq_to_tuple(spaces)
axes = x.domain_axes[spaces[0]]
axis = axes[0]
nval = len(self._dposc)
......@@ -115,13 +115,13 @@ class LaplaceOperator(EndomorphicOperator):
return Field(self.domain, val=ret).weight(power=-0.5, spaces=spaces)
def _adjoint_times(self, x, spaces):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
if spaces is None:
# this case means that x lives on only one space, which is
# identical to the space in the domain of `self`. Otherwise the
# input check of LinearOperator would have failed.
axes = x.domain_axes[0]
else:
spaces = utilities.cast_iseq_to_tuple(spaces)
axes = x.domain_axes[spaces[0]]
axis = axes[0]
nval = len(self._dposc)
......
......@@ -266,8 +266,9 @@ class LinearOperator(with_metaclass(
else:
spaces = self.default_spaces[::-1]
# sanitize the `spaces` and `types` input
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
# sanitize the `spaces` input
if spaces is not None:
spaces = utilities.cast_iseq_to_tuple(spaces)
# if the operator's domain is set to something, there are two valid
# cases:
......@@ -281,9 +282,8 @@ class LinearOperator(with_metaclass(
if spaces is None:
if self_domain != x.domain:
raise ValueError(
"The operator's and and field's domains don't "
"match.")
raise ValueError("The operator's and and field's domains "
"don't match.")
else:
for i, space_index in enumerate(spaces):
if x.domain[space_index] != self_domain[i]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment