Commit fcbd1ea9 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'even_more_operator_work' into 'NIFTy_5'

Even more operator work

See merge request ift/nifty-dev!59
parents 7c9500fa 2db0d555
......@@ -12,19 +12,21 @@ from .. import utilities
class FieldZeroPadder(LinearOperator):
def __init__(self, domain, factor, space=0):
def __init__(self, domain, new_shape, space=0):
super(FieldZeroPadder, self).__init__()
self._domain = DomainTuple.make(domain)
self._space = utilities.infer_space(self._domain, space)
dom = self._domain[self._space]
if not isinstance(dom, RGSpace):
raise TypeError("RGSpace required")
if not len(dom.shape) == 1:
raise TypeError("RGSpace must be one-dimensional")
if dom.harmonic:
raise TypeError("RGSpace must not be harmonic")
tgt = RGSpace((int(factor*dom.shape[0]),), dom.distances)
if len(new_shape) != len(dom.shape):
raise ValueError("Shape mismatch")
if any([a < b for a, b in zip(new_shape, dom.shape)]):
raise ValueError("New shape must be larger than old shape")
tgt = RGSpace(new_shape, dom.distances)
self._target = list(self._domain)
self._target[self._space] = tgt
self._target = DomainTuple.make(self._target)
......@@ -47,20 +49,21 @@ class FieldZeroPadder(LinearOperator):
dax = dobj.distaxis(x)
shp_in = x.shape
shp_out = self._tgt(mode).shape
ax = self._target.axes[self._space][0]
if dax == ax:
x = dobj.redistribute(x, nodist=(ax,))
axbefore = self._target.axes[self._space][0]
axes = self._target.axes[self._space]
if dax in axes:
x = dobj.redistribute(x, nodist=axes)
curax = dobj.distaxis(x)
if mode == self.ADJOINT_TIMES:
newarr = np.empty(dobj.local_shape(shp_out, curax), dtype=x.dtype)
newarr[()] = dobj.local_data(x)[(slice(None),)*ax +
(slice(0, shp_out[ax]),)]
sl = tuple(slice(0, shp_out[axis]) for axis in axes)
newarr[()] = dobj.local_data(x)[(slice(None),)*axbefore + sl]
else:
newarr = np.zeros(dobj.local_shape(shp_out, curax), dtype=x.dtype)
newarr[(slice(None),)*ax +
(slice(0, shp_in[ax]),)] = dobj.local_data(x)
sl = tuple(slice(0, shp_in[axis]) for axis in axes)
newarr[(slice(None),)*axbefore + sl] = dobj.local_data(x)
newarr = dobj.from_local_data(shp_out, newarr, distaxis=curax)
if dax == ax:
newarr = dobj.redistribute(newarr, dist=ax)
if dax in axes:
newarr = dobj.redistribute(newarr, dist=dax)
return Field(self._tgt(mode), val=newarr)
......@@ -74,7 +74,6 @@ class QHTOperator(LinearOperator):
n = self._domain.axes[self._space]
rng = n if mode == self.TIMES else reversed(n)
ax = dobj.distaxis(x)
globshape = x.shape
for i in rng:
sl = (slice(None),)*i + (slice(1, None),)
if i == ax:
......
......@@ -43,7 +43,6 @@ class SymmetrizingOperator(EndomorphicOperator):
self._check_input(x, mode)
tmp = x.val.copy()
ax = dobj.distaxis(tmp)
globshape = tmp.shape
for i in self._domain.axes[self._space]:
lead = (slice(None),)*i
if i == ax:
......
......@@ -119,9 +119,10 @@ class Consistency_Tests(unittest.TestCase):
@expand(product([0, 2], [2, 2.7], [np.float64, np.complex128]))
def testZeroPadder(self, space, factor, dtype):
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7),
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7, 12),
ift.HPSpace(4))
op = ift.FieldZeroPadder(dom, factor, space)
newshape = [factor*l for l in dom[space].shape]
op = ift.FieldZeroPadder(dom, newshape, space)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([(ift.RGSpace(10, harmonic=True), 4, 0),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment