From 8648a6b4bf2adda80b75f3e9d30098c3b8bd218e Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Sat, 14 Jul 2018 21:10:44 +0200
Subject: [PATCH] generalize FieldZeroPadder

---
 nifty5/operators/field_zero_padder.py     | 24 +++++++++++------------
 nifty5/operators/qht_operator.py          |  1 -
 nifty5/operators/symmetrizing_operator.py |  1 -
 test/test_operators/test_adjoint.py       |  2 +-
 4 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/nifty5/operators/field_zero_padder.py b/nifty5/operators/field_zero_padder.py
index 5074aade5..267f59110 100644
--- a/nifty5/operators/field_zero_padder.py
+++ b/nifty5/operators/field_zero_padder.py
@@ -19,12 +19,11 @@ class FieldZeroPadder(LinearOperator):
         dom = self._domain[self._space]
         if not isinstance(dom, RGSpace):
             raise TypeError("RGSpace required")
-        if not len(dom.shape) == 1:
-            raise TypeError("RGSpace must be one-dimensional")
         if dom.harmonic:
             raise TypeError("RGSpace must not be harmonic")
 
-        tgt = RGSpace((int(factor*dom.shape[0]),), dom.distances)
+        newshp = tuple(factor*s for s in dom.shape)
+        tgt = RGSpace(newshp, dom.distances)
         self._target = list(self._domain)
         self._target[self._space] = tgt
         self._target = DomainTuple.make(self._target)
@@ -47,20 +46,21 @@ class FieldZeroPadder(LinearOperator):
         dax = dobj.distaxis(x)
         shp_in = x.shape
         shp_out = self._tgt(mode).shape
-        ax = self._target.axes[self._space][0]
-        if dax == ax:
-            x = dobj.redistribute(x, nodist=(ax,))
+        axbefore = self._target.axes[self._space][0]
+        axes = self._target.axes[self._space]
+        if dax in axes:
+            x = dobj.redistribute(x, nodist=axes)
         curax = dobj.distaxis(x)
 
         if mode == self.ADJOINT_TIMES:
             newarr = np.empty(dobj.local_shape(shp_out, curax), dtype=x.dtype)
-            newarr[()] = dobj.local_data(x)[(slice(None),)*ax +
-                                            (slice(0, shp_out[ax]),)]
+            sl = tuple(slice(0, shp_out[axis]) for axis in axes)
+            newarr[()] = dobj.local_data(x)[(slice(None),)*axbefore + sl]
         else:
             newarr = np.zeros(dobj.local_shape(shp_out, curax), dtype=x.dtype)
-            newarr[(slice(None),)*ax +
-                   (slice(0, shp_in[ax]),)] = dobj.local_data(x)
+            sl = tuple(slice(0, shp_in[axis]) for axis in axes)
+            newarr[(slice(None),)*axbefore + sl] = dobj.local_data(x)
         newarr = dobj.from_local_data(shp_out, newarr, distaxis=curax)
-        if dax == ax:
-            newarr = dobj.redistribute(newarr, dist=ax)
+        if dax in axes:
+            newarr = dobj.redistribute(newarr, dist=dax)
         return Field(self._tgt(mode), val=newarr)
diff --git a/nifty5/operators/qht_operator.py b/nifty5/operators/qht_operator.py
index 3eebc3381..76d7b1d66 100644
--- a/nifty5/operators/qht_operator.py
+++ b/nifty5/operators/qht_operator.py
@@ -80,7 +80,6 @@ class QHTOperator(LinearOperator):
         n = self._domain.axes[self._space]
         rng = n if mode == self.TIMES else reversed(n)
         ax = dobj.distaxis(x)
-        globshape = x.shape
         for i in rng:
             sl = (slice(None),)*i + (slice(1, None),)
             if i == ax:
diff --git a/nifty5/operators/symmetrizing_operator.py b/nifty5/operators/symmetrizing_operator.py
index 8a2aa881e..a8c9e49ed 100644
--- a/nifty5/operators/symmetrizing_operator.py
+++ b/nifty5/operators/symmetrizing_operator.py
@@ -43,7 +43,6 @@ class SymmetrizingOperator(EndomorphicOperator):
         self._check_input(x, mode)
         tmp = x.val.copy()
         ax = dobj.distaxis(tmp)
-        globshape = tmp.shape
         for i in self._domain.axes[self._space]:
             lead = (slice(None),)*i
             if i == ax:
diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py
index f6f9e9542..82064f7ca 100644
--- a/test/test_operators/test_adjoint.py
+++ b/test/test_operators/test_adjoint.py
@@ -119,7 +119,7 @@ class Consistency_Tests(unittest.TestCase):
 
     @expand(product([0, 2], [2, 2.7], [np.float64, np.complex128]))
     def testZeroPadder(self, space, factor, dtype):
-        dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7),
+        dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7, 12),
                ift.HPSpace(4))
         op = ift.FieldZeroPadder(dom, factor, space)
         ift.extra.consistency_check(op, dtype, dtype)
-- 
GitLab