diff --git a/nifty5/operators/field_zero_padder.py b/nifty5/operators/field_zero_padder.py index 5074aade57910913d7584247ed23425ddc348473..267f591107e867c0d36f5278eb7a7c2d2c50ed65 100644 --- a/nifty5/operators/field_zero_padder.py +++ b/nifty5/operators/field_zero_padder.py @@ -19,12 +19,11 @@ class FieldZeroPadder(LinearOperator): dom = self._domain[self._space] if not isinstance(dom, RGSpace): raise TypeError("RGSpace required") - if not len(dom.shape) == 1: - raise TypeError("RGSpace must be one-dimensional") if dom.harmonic: raise TypeError("RGSpace must not be harmonic") - tgt = RGSpace((int(factor*dom.shape[0]),), dom.distances) + newshp = tuple(factor*s for s in dom.shape) + tgt = RGSpace(newshp, dom.distances) self._target = list(self._domain) self._target[self._space] = tgt self._target = DomainTuple.make(self._target) @@ -47,20 +46,21 @@ class FieldZeroPadder(LinearOperator): dax = dobj.distaxis(x) shp_in = x.shape shp_out = self._tgt(mode).shape - ax = self._target.axes[self._space][0] - if dax == ax: - x = dobj.redistribute(x, nodist=(ax,)) + axbefore = self._target.axes[self._space][0] + axes = self._target.axes[self._space] + if dax in axes: + x = dobj.redistribute(x, nodist=axes) curax = dobj.distaxis(x) if mode == self.ADJOINT_TIMES: newarr = np.empty(dobj.local_shape(shp_out, curax), dtype=x.dtype) - newarr[()] = dobj.local_data(x)[(slice(None),)*ax + - (slice(0, shp_out[ax]),)] + sl = tuple(slice(0, shp_out[axis]) for axis in axes) + newarr[()] = dobj.local_data(x)[(slice(None),)*axbefore + sl] else: newarr = np.zeros(dobj.local_shape(shp_out, curax), dtype=x.dtype) - newarr[(slice(None),)*ax + - (slice(0, shp_in[ax]),)] = dobj.local_data(x) + sl = tuple(slice(0, shp_in[axis]) for axis in axes) + newarr[(slice(None),)*axbefore + sl] = dobj.local_data(x) newarr = dobj.from_local_data(shp_out, newarr, distaxis=curax) - if dax == ax: - newarr = dobj.redistribute(newarr, dist=ax) + if dax in axes: + newarr = dobj.redistribute(newarr, dist=dax) return Field(self._tgt(mode), val=newarr) diff --git a/nifty5/operators/qht_operator.py b/nifty5/operators/qht_operator.py index 3eebc338146c167b3ebc40b837291b00accccf0d..76d7b1d6649c7b66813dec3cde387384d3c67658 100644 --- a/nifty5/operators/qht_operator.py +++ b/nifty5/operators/qht_operator.py @@ -80,7 +80,6 @@ class QHTOperator(LinearOperator): n = self._domain.axes[self._space] rng = n if mode == self.TIMES else reversed(n) ax = dobj.distaxis(x) - globshape = x.shape for i in rng: sl = (slice(None),)*i + (slice(1, None),) if i == ax: diff --git a/nifty5/operators/symmetrizing_operator.py b/nifty5/operators/symmetrizing_operator.py index 8a2aa881e6c2b3e49cdba19275a2d9b16b54b89a..a8c9e49eda5fac886e2013480c4f67ff77143e89 100644 --- a/nifty5/operators/symmetrizing_operator.py +++ b/nifty5/operators/symmetrizing_operator.py @@ -43,7 +43,6 @@ class SymmetrizingOperator(EndomorphicOperator): self._check_input(x, mode) tmp = x.val.copy() ax = dobj.distaxis(tmp) - globshape = tmp.shape for i in self._domain.axes[self._space]: lead = (slice(None),)*i if i == ax: diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py index f6f9e9542e5cd41d35eb7d9444164e663b2513cd..82064f7caa864c306c96561b158a79eaf22d1457 100644 --- a/test/test_operators/test_adjoint.py +++ b/test/test_operators/test_adjoint.py @@ -119,7 +119,7 @@ class Consistency_Tests(unittest.TestCase): @expand(product([0, 2], [2, 2.7], [np.float64, np.complex128])) def testZeroPadder(self, space, factor, dtype): - dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7), + dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7, 12), ift.HPSpace(4)) op = ift.FieldZeroPadder(dom, factor, space) ift.extra.consistency_check(op, dtype, dtype)