Commit b20ee9ff authored by Lukas Platz's avatar Lukas Platz
Browse files

make CentralFieldZeroPadder split the central values per default on harmonic spaces

parent a9c0d78c
Pipeline #107094 passed with stages
in 20 minutes and 55 seconds
...@@ -39,28 +39,37 @@ class CentralFieldZeroPadder(LinearOperator): ...@@ -39,28 +39,37 @@ class CentralFieldZeroPadder(LinearOperator):
space : int space : int
The index of the subdomain to be zero-padded. If None, it is set to 0 The index of the subdomain to be zero-padded. If None, it is set to 0
if domain contains exactly one space. domain[space] must be an RGSpace. if domain contains exactly one space. domain[space] must be an RGSpace.
split_even : boolean split_even : None or boolean
When set to True and padding on an axis with even length, the When doing central padding on an axis with an even length, the "central"
"central" entry will be split up. This is useful for padding in entry will get distributed to two locations. When padding harmonic fields,
harmonic spaces. only half of the value should be written to each of the target locations
to preserve the total power.
If True or False, the splitting is or is not performed regardless of
input. If set to None (default), splitting will be enabled for harmonic
spaces and disabled otherwise.
""" """
def __init__(self, domain, new_shape, space=0, split_even=False): def __init__(self, domain, new_shape, space=0, split_even=None):
self._domain = DomainTuple.make(domain) self._domain = DomainTuple.make(domain)
self._space = utilities.infer_space(self._domain, space) self._space = utilities.infer_space(self._domain, space)
self._split_even = split_even
dom = self._domain[self._space] dom = self._domain[self._space]
if not isinstance(dom, RGSpace): if not isinstance(dom, RGSpace):
raise TypeError("RGSpace required") raise TypeError("RGSpace required")
if len(new_shape) != len(dom.shape): if len(new_shape) != len(dom.shape):
raise ValueError("Shape mismatch") raise ValueError("Shape mismatch")
if any([a < b for a, b in zip(new_shape, dom.shape)]): if any([a < b for a, b in zip(new_shape, dom.shape)]):
raise ValueError("New shape must not be smaller than old shape") raise ValueError("New shape must not be smaller than old shape")
self._target = list(self._domain) self._target = list(self._domain)
self._target[self._space] = RGSpace(new_shape, dom.distances, self._target[self._space] = RGSpace(new_shape, dom.distances,
dom.harmonic) dom.harmonic)
self._target = DomainTuple.make(self._target) self._target = DomainTuple.make(self._target)
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
if split_even is None:
split_even = True if dom.harmonic else False
self._split_even = split_even
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
v = x.val v = x.val
......
...@@ -208,10 +208,10 @@ def testDomainTupleFieldInserter(): ...@@ -208,10 +208,10 @@ def testDomainTupleFieldInserter():
@pmp('space', [0, 2]) @pmp('space', [0, 2])
@pmp('factor', [1, 2, 2.7]) @pmp('factor', [1, 2, 2.7])
@pmp('split_even', [False, True]) @pmp('split_even', [None, False, True])
def testCentralZeroPadder(space, factor, dtype, split_even): def testCentralZeroPadder(space, factor, dtype, split_even):
dom = (ift.RGSpace(4), ift.UnstructuredDomain(5), ift.RGSpace(3, 4), dom = (ift.RGSpace(4, harmonic=True), ift.UnstructuredDomain(5),
ift.HPSpace(2)) ift.RGSpace(3, 4), ift.HPSpace(2))
newshape = [int(factor*ll) for ll in dom[space].shape] newshape = [int(factor*ll) for ll in dom[space].shape]
op = ift.CentralFieldZeroPadder(dom, newshape, space, split_even) op = ift.CentralFieldZeroPadder(dom, newshape, space, split_even)
ift.extra.check_linear_operator(op, dtype, dtype) ift.extra.check_linear_operator(op, dtype, dtype)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment