Commit fcbd1ea9 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'even_more_operator_work' into 'NIFTy_5'

Even more operator work

See merge request ift/nifty-dev!59
parents 7c9500fa 2db0d555
...@@ -12,19 +12,21 @@ from .. import utilities ...@@ -12,19 +12,21 @@ from .. import utilities
class FieldZeroPadder(LinearOperator): class FieldZeroPadder(LinearOperator):
def __init__(self, domain, factor, space=0): def __init__(self, domain, new_shape, space=0):
super(FieldZeroPadder, self).__init__() super(FieldZeroPadder, self).__init__()
self._domain = DomainTuple.make(domain) self._domain = DomainTuple.make(domain)
self._space = utilities.infer_space(self._domain, space) self._space = utilities.infer_space(self._domain, space)
dom = self._domain[self._space] dom = self._domain[self._space]
if not isinstance(dom, RGSpace): if not isinstance(dom, RGSpace):
raise TypeError("RGSpace required") raise TypeError("RGSpace required")
if not len(dom.shape) == 1:
raise TypeError("RGSpace must be one-dimensional")
if dom.harmonic: if dom.harmonic:
raise TypeError("RGSpace must not be harmonic") raise TypeError("RGSpace must not be harmonic")
tgt = RGSpace((int(factor*dom.shape[0]),), dom.distances) if len(new_shape) != len(dom.shape):
raise ValueError("Shape mismatch")
if any([a < b for a, b in zip(new_shape, dom.shape)]):
raise ValueError("New shape must be larger than old shape")
tgt = RGSpace(new_shape, dom.distances)
self._target = list(self._domain) self._target = list(self._domain)
self._target[self._space] = tgt self._target[self._space] = tgt
self._target = DomainTuple.make(self._target) self._target = DomainTuple.make(self._target)
...@@ -47,20 +49,21 @@ class FieldZeroPadder(LinearOperator): ...@@ -47,20 +49,21 @@ class FieldZeroPadder(LinearOperator):
dax = dobj.distaxis(x) dax = dobj.distaxis(x)
shp_in = x.shape shp_in = x.shape
shp_out = self._tgt(mode).shape shp_out = self._tgt(mode).shape
ax = self._target.axes[self._space][0] axbefore = self._target.axes[self._space][0]
if dax == ax: axes = self._target.axes[self._space]
x = dobj.redistribute(x, nodist=(ax,)) if dax in axes:
x = dobj.redistribute(x, nodist=axes)
curax = dobj.distaxis(x) curax = dobj.distaxis(x)
if mode == self.ADJOINT_TIMES: if mode == self.ADJOINT_TIMES:
newarr = np.empty(dobj.local_shape(shp_out, curax), dtype=x.dtype) newarr = np.empty(dobj.local_shape(shp_out, curax), dtype=x.dtype)
newarr[()] = dobj.local_data(x)[(slice(None),)*ax + sl = tuple(slice(0, shp_out[axis]) for axis in axes)
(slice(0, shp_out[ax]),)] newarr[()] = dobj.local_data(x)[(slice(None),)*axbefore + sl]
else: else:
newarr = np.zeros(dobj.local_shape(shp_out, curax), dtype=x.dtype) newarr = np.zeros(dobj.local_shape(shp_out, curax), dtype=x.dtype)
newarr[(slice(None),)*ax + sl = tuple(slice(0, shp_in[axis]) for axis in axes)
(slice(0, shp_in[ax]),)] = dobj.local_data(x) newarr[(slice(None),)*axbefore + sl] = dobj.local_data(x)
newarr = dobj.from_local_data(shp_out, newarr, distaxis=curax) newarr = dobj.from_local_data(shp_out, newarr, distaxis=curax)
if dax == ax: if dax in axes:
newarr = dobj.redistribute(newarr, dist=ax) newarr = dobj.redistribute(newarr, dist=dax)
return Field(self._tgt(mode), val=newarr) return Field(self._tgt(mode), val=newarr)
...@@ -74,7 +74,6 @@ class QHTOperator(LinearOperator): ...@@ -74,7 +74,6 @@ class QHTOperator(LinearOperator):
n = self._domain.axes[self._space] n = self._domain.axes[self._space]
rng = n if mode == self.TIMES else reversed(n) rng = n if mode == self.TIMES else reversed(n)
ax = dobj.distaxis(x) ax = dobj.distaxis(x)
globshape = x.shape
for i in rng: for i in rng:
sl = (slice(None),)*i + (slice(1, None),) sl = (slice(None),)*i + (slice(1, None),)
if i == ax: if i == ax:
......
...@@ -43,7 +43,6 @@ class SymmetrizingOperator(EndomorphicOperator): ...@@ -43,7 +43,6 @@ class SymmetrizingOperator(EndomorphicOperator):
self._check_input(x, mode) self._check_input(x, mode)
tmp = x.val.copy() tmp = x.val.copy()
ax = dobj.distaxis(tmp) ax = dobj.distaxis(tmp)
globshape = tmp.shape
for i in self._domain.axes[self._space]: for i in self._domain.axes[self._space]:
lead = (slice(None),)*i lead = (slice(None),)*i
if i == ax: if i == ax:
......
...@@ -119,9 +119,10 @@ class Consistency_Tests(unittest.TestCase): ...@@ -119,9 +119,10 @@ class Consistency_Tests(unittest.TestCase):
@expand(product([0, 2], [2, 2.7], [np.float64, np.complex128])) @expand(product([0, 2], [2, 2.7], [np.float64, np.complex128]))
def testZeroPadder(self, space, factor, dtype): def testZeroPadder(self, space, factor, dtype):
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7), dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7, 12),
ift.HPSpace(4)) ift.HPSpace(4))
op = ift.FieldZeroPadder(dom, factor, space) newshape = [factor*l for l in dom[space].shape]
op = ift.FieldZeroPadder(dom, newshape, space)
ift.extra.consistency_check(op, dtype, dtype) ift.extra.consistency_check(op, dtype, dtype)
@expand(product([(ift.RGSpace(10, harmonic=True), 4, 0), @expand(product([(ift.RGSpace(10, harmonic=True), 4, 0),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment