Commit b2e8cde0 authored by Lukas Platz's avatar Lukas Platz
Browse files

add 'OffsetFieldZeroPadder'

parent 1db37c41
......@@ -32,7 +32,7 @@ from .operators.endomorphic_operator import EndomorphicOperator
from .operators.harmonic_operators import (
FFTOperator, HartleyOperator, SHTOperator, HarmonicTransformOperator,
HarmonicSmoothingOperator)
from .operators.field_zero_padder import FieldZeroPadder
from .operators.field_zero_padder import FieldZeroPadder, OffsetFieldZeroPadder
from .operators.inversion_enabler import InversionEnabler
from .operators.mask_operator import MaskOperator
from .operators.regridding_operator import RegriddingOperator
......
......@@ -112,3 +112,72 @@ class FieldZeroPadder(LinearOperator):
curshp[d] = xnew.shape[d]
v = xnew
return Field(self._tgt(mode), v)
class OffsetFieldZeroPadder(LinearOperator):
"""FieldZeroPadder with choosable offset
Parameters
----------
domain : Domain, DomainTuple or tuple of Domain
The operator's input domain.
new_shape : list or tuple of int
The new dimensions of the subdomain which is zero-padded.
No entry must be smaller than the corresponding dimension in the
operator's domain.
space : int
The index of the subdomain to be zero-padded. If None, it is set to 0
if domain contains exactly one space. domain[space] must be an RGSpace.
offset : tuple of int or None
Where in the new zero-padded array to place the input field.
If `None` is given, place the field at zero offset.
"""
def __init__(self, domain, new_shape, space=0, offset=None):
self._domain = DomainTuple.make(domain)
self._space = utilities.infer_space(self._domain, space)
dom = self._domain[self._space]
if offset is None:
self._offset = (0, ) * len(dom.shape)
else:
self._offset = offset
if not isinstance(dom, RGSpace):
raise TypeError("RGSpace required")
if len(new_shape) != len(dom.shape):
raise ValueError("New shape mismatch")
if len(self._offset) != len(dom.shape):
raise ValueError("offset shape mismatch")
if any([a < b for a, b in zip(new_shape, dom.shape)]):
raise ValueError("New shape must not be smaller than old shape")
end_idx = [a + b for a, b in zip(dom.shape, self._offset)]
if any([a < b for a, b in zip(new_shape, end_idx)]):
raise ValueError(
"Input field pasted at offset would overflow target boundaries"
)
self._target = list(self._domain)
self._target[self._space] = RGSpace(new_shape, dom.distances,
dom.harmonic)
self._target = DomainTuple.make(self._target)
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
dom = self._domain[self._space]
num_spaces = len(self._target.axes)
idx = tuple()
for i in range(num_spaces):
for ax_idx in range(len(self._target.axes[i])):
if i == self._space:
ax_offset = self._offset[ax_idx]
idx += (slice(ax_offset, dom.shape[ax_idx] + ax_offset), )
else:
idx += (slice(None), )
if mode == self.TIMES:
xnew = np.zeros(self._target.shape, dtype=x.val.dtype)
xnew[idx] = x.val
else: # Adjoint times
xnew = x.val[idx]
return Field(self._tgt(mode), xnew)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment