Commit 7be2eb8f authored by Jakob Roth's avatar Jakob Roth
Browse files

fix SliceOperator for multidimensional spaces

parent a55108a9
Pipeline #95051 passed with stages
in 11 minutes and 47 seconds
......@@ -37,9 +37,11 @@ class SliceOperator(LinearOperator):
----------
domain : Domain, DomainTuple or tuple of Domain
The operator's input domain.
new_shape : tuple of integers or None
new_shape : tuple of tuples or integers, or None
The shape of the target domain with None indicating to copy the shape
of the original domain for this axis.
of the original domain for this axis. For example ((10, 5), 100) for a
DomainTuple with two entires, the first having shape (10, 5) and the
second having shape 100
center : bool, optional
Whether to center the slice that is selected in the input field.
preserve_dist: bool, optional
......@@ -47,19 +49,25 @@ class SliceOperator(LinearOperator):
"""
def __init__(self, domain, new_shape, center=False, preserve_dist=True):
self._domain = DomainTuple.make(domain)
if len(new_shape) != len(self._domain.shape):
if len(new_shape) != len(self._domain):
ve = (
f"shape ({new_shape}) is incompatible with the shape of the"
f" domain ({self._domain.shape})"
)
raise ValueError(ve)
for i, shape in enumerate(new_shape):
if len(np.atleast_1d(shape)) != len(self._domain[i].shape):
ve = (
f"shape of subspace ({i}) is incompatible with the domain"
)
raise ValueError(ve)
tgt = []
slc_by_ax = []
for i, d in enumerate(self._domain):
if new_shape[i] is None or self._domain.shape[i] == new_shape[i]:
tgt += [d]
elif new_shape[i] < self._domain.shape[i]:
elif np.all(np.array(new_shape[i]) < np.array(d.shape)):
dom_kw = dict()
if isinstance(d, RGSpace):
if preserve_dist:
......@@ -78,14 +86,15 @@ class SliceOperator(LinearOperator):
raise ValueError(ve)
if center:
slc_start = np.floor(
(self._domain.shape[i] - new_shape[i]) / 2.
).astype(int)
slc_end = slc_start + new_shape[i]
for j, n_pix in enumerate(np.atleast_1d(new_shape[i])):
slc_start = np.floor((d.shape[j] - n_pix) / 2.).astype(int)
slc_end = slc_start + n_pix
slc_by_ax += [slice(slc_start, slc_end)]
else:
slc_start = 0
slc_end = new_shape[i]
slc_by_ax += [slice(slc_start, slc_end)]
for n_pix in np.atleast_1d(new_shape[i]):
slc_start = 0
slc_end = n_pix
slc_by_ax += [slice(slc_start, slc_end)]
self._slc_by_ax = tuple(slc_by_ax)
self._target = DomainTuple.make(tgt)
......@@ -102,8 +111,10 @@ class SliceOperator(LinearOperator):
return Field.from_raw(self.domain, res)
def __str__(self):
ss = (f"{self.__class__.__name__}"
f"({self.domain.shape} -> {self.target.shape})")
ss = (
f"{self.__class__.__name__}"
f"({self.domain.shape} -> {self.target.shape})"
)
return ss
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment