Skip to content
Snippets Groups Projects
Commit 8648a6b4 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

generalize FieldZeroPadder

parent af1d849c
No related branches found
No related tags found
No related merge requests found
......@@ -19,12 +19,11 @@ class FieldZeroPadder(LinearOperator):
dom = self._domain[self._space]
if not isinstance(dom, RGSpace):
raise TypeError("RGSpace required")
if not len(dom.shape) == 1:
raise TypeError("RGSpace must be one-dimensional")
if dom.harmonic:
raise TypeError("RGSpace must not be harmonic")
tgt = RGSpace((int(factor*dom.shape[0]),), dom.distances)
newshp = tuple(factor*s for s in dom.shape)
tgt = RGSpace(newshp, dom.distances)
self._target = list(self._domain)
self._target[self._space] = tgt
self._target = DomainTuple.make(self._target)
......@@ -47,20 +46,21 @@ class FieldZeroPadder(LinearOperator):
dax = dobj.distaxis(x)
shp_in = x.shape
shp_out = self._tgt(mode).shape
ax = self._target.axes[self._space][0]
if dax == ax:
x = dobj.redistribute(x, nodist=(ax,))
axbefore = self._target.axes[self._space][0]
axes = self._target.axes[self._space]
if dax in axes:
x = dobj.redistribute(x, nodist=axes)
curax = dobj.distaxis(x)
if mode == self.ADJOINT_TIMES:
newarr = np.empty(dobj.local_shape(shp_out, curax), dtype=x.dtype)
newarr[()] = dobj.local_data(x)[(slice(None),)*ax +
(slice(0, shp_out[ax]),)]
sl = tuple(slice(0, shp_out[axis]) for axis in axes)
newarr[()] = dobj.local_data(x)[(slice(None),)*axbefore + sl]
else:
newarr = np.zeros(dobj.local_shape(shp_out, curax), dtype=x.dtype)
newarr[(slice(None),)*ax +
(slice(0, shp_in[ax]),)] = dobj.local_data(x)
sl = tuple(slice(0, shp_in[axis]) for axis in axes)
newarr[(slice(None),)*axbefore + sl] = dobj.local_data(x)
newarr = dobj.from_local_data(shp_out, newarr, distaxis=curax)
if dax == ax:
newarr = dobj.redistribute(newarr, dist=ax)
if dax in axes:
newarr = dobj.redistribute(newarr, dist=dax)
return Field(self._tgt(mode), val=newarr)
......@@ -80,7 +80,6 @@ class QHTOperator(LinearOperator):
n = self._domain.axes[self._space]
rng = n if mode == self.TIMES else reversed(n)
ax = dobj.distaxis(x)
globshape = x.shape
for i in rng:
sl = (slice(None),)*i + (slice(1, None),)
if i == ax:
......
......@@ -43,7 +43,6 @@ class SymmetrizingOperator(EndomorphicOperator):
self._check_input(x, mode)
tmp = x.val.copy()
ax = dobj.distaxis(tmp)
globshape = tmp.shape
for i in self._domain.axes[self._space]:
lead = (slice(None),)*i
if i == ax:
......
......@@ -119,7 +119,7 @@ class Consistency_Tests(unittest.TestCase):
@expand(product([0, 2], [2, 2.7], [np.float64, np.complex128]))
def testZeroPadder(self, space, factor, dtype):
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7),
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7, 12),
ift.HPSpace(4))
op = ift.FieldZeroPadder(dom, factor, space)
ift.extra.consistency_check(op, dtype, dtype)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment