Commit da95d1c9 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

make Zeropadder more flexible

parent ad593f04
......@@ -41,8 +41,8 @@ class FieldZeroPadder(LinearOperator):
if len(new_shape) != len(dom.shape):
raise ValueError("Shape mismatch")
if any([a <= b for a, b in zip(new_shape, dom.shape)]):
raise ValueError("New shape must be larger than old shape")
if any([a < b for a, b in zip(new_shape, dom.shape)]):
raise ValueError("New shape must not be smaller than old shape")
self._target = list(self._domain)
self._target[self._space] = RGSpace(new_shape, dom.distances)
self._target = DomainTuple.make(self._target)
......@@ -54,6 +54,9 @@ class FieldZeroPadder(LinearOperator):
curshp = list(self._dom(mode).shape)
tgtshp = self._tgt(mode).shape
for d in self._target.axes[self._space]:
if v.shape[d] == tgtshp[d]: # nothing to do
continue
idx = (slice(None),) * d
v, x = dobj.ensure_not_distributed(v, (d,))
......
......@@ -209,7 +209,7 @@ class Consistency_Tests(unittest.TestCase):
op = ift.SymmetrizingOperator(dom, space)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([0, 2], [2, 2.7], [np.float64, np.complex128],
@expand(product([0, 2], [1, 2, 2.7], [np.float64, np.complex128],
[False, True]))
def testZeroPadder(self, space, factor, dtype, central):
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7, 12),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment