Commit 3e395a7b authored by Martin Reinecke's avatar Martin Reinecke

add central padding support to FieldZeroPadder (experimental)

parent bb7fdd73
......@@ -133,6 +133,9 @@ class ChainOperator(LinearOperator):
x = op.apply(x, mode)
return x
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "ChainOperator:\n" + utilities.indent(subs)
# def draw_sample(self, from_inverse=False, dtype=np.float64):
# from ..sugar import from_random
......@@ -144,7 +147,3 @@ class ChainOperator(LinearOperator):
# for op in self._ops:
# samp = op.process_sample(samp, from_inverse)
# return samp
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "ChainOperator:\n" + utilities.indent(subs)
......@@ -29,9 +29,10 @@ from .linear_operator import LinearOperator
class FieldZeroPadder(LinearOperator):
def __init__(self, domain, new_shape, space=0):
def __init__(self, domain, new_shape, space=0, central=False):
self._domain = DomainTuple.make(domain)
self._space = utilities.infer_space(self._domain, space)
self._central = central
dom = self._domain[self._space]
if not isinstance(dom, RGSpace):
raise TypeError("RGSpace required")
......@@ -40,7 +41,7 @@ class FieldZeroPadder(LinearOperator):
if len(new_shape) != len(dom.shape):
raise ValueError("Shape mismatch")
if any([a < b for a, b in zip(new_shape, dom.shape)]):
if any([a <= b for a, b in zip(new_shape, dom.shape)]):
raise ValueError("New shape must be larger than old shape")
self._target = list(self._domain)
self._target[self._space] = RGSpace(new_shape, dom.distances)
......@@ -61,9 +62,35 @@ class FieldZeroPadder(LinearOperator):
shp = list(x.shape)
shp[d] = tgtshp[d]
xnew = np.zeros(shp, dtype=x.dtype)
xnew[idx + (slice(0, x.shape[d]),)] = x
if self._central:
Nyquist = x.shape[d]//2
i1 = idx + (slice(0, Nyquist+1),)
xnew[i1] = x[i1]
i1 = idx + (slice(None, -(Nyquist+1), -1),)
xnew[i1] = x[i1]
# if (x.shape[d] & 1) == 0: # even number of pixels
# print (Nyquist, x.shape[d]-Nyquist)
# i1 = idx+(Nyquist,)
# xnew[i1] *= 0.5
# i1 = idx+(-Nyquist,)
# xnew[i1] *= 0.5
else:
xnew[idx + (slice(0, x.shape[d]),)] = x
else: # ADJOINT_TIMES
xnew = x[idx + (slice(0, tgtshp[d]),)]
if self._central:
shp = list(x.shape)
shp[d] = tgtshp[d]
xnew = np.zeros(shp, dtype=x.dtype)
Nyquist = xnew.shape[d]//2
i1 = idx + (slice(0, Nyquist+1),)
xnew[i1] = x[i1]
i1 = idx + (slice(None, -(Nyquist+1), -1),)
xnew[i1] += x[i1]
# if (xnew.shape[d] & 1) == 0: # even number of pixels
# i1 = idx+(Nyquist,)
# xnew[i1] *= 0.5
else:
xnew = x[idx + (slice(0, tgtshp[d]),)]
curshp[d] = xnew.shape[d]
v = dobj.from_local_data(curshp, xnew, distaxis=dobj.distaxis(v))
......
......@@ -51,7 +51,7 @@ class SlopeOperator(LinearOperator):
raise TypeError
self._domain = DomainTuple.make(UnstructuredDomain((2,)))
self._target = DomainTuple.make(target)
self._capability = self.TIMES | self.ADJOINT_TIMES
self._capability = self.TIMES | self.ADJOINT_TIMES
self._sigmas = sigmas
self.ndim = len(self.target[0].shape)
......
......@@ -201,19 +201,20 @@ class Consistency_Tests(unittest.TestCase):
op = ift.SymmetrizingOperator(dom, space)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([0, 2], [2, 2.7], [np.float64, np.complex128]))
def testZeroPadder(self, space, factor, dtype):
@expand(product([0, 2], [2, 2.7], [np.float64, np.complex128],
[False, True]))
def testZeroPadder(self, space, factor, dtype, central):
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7, 12),
ift.HPSpace(4))
newshape = [factor*l for l in dom[space].shape]
op = ift.FieldZeroPadder(dom, newshape, space)
newshape = [int(factor*l) for l in dom[space].shape]
op = ift.FieldZeroPadder(dom, newshape, space, central)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([0, 2], [2, 2.7], [np.float64, np.complex128]))
def testZeroPadder2(self, space, factor, dtype):
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7, 12),
ift.HPSpace(4))
newshape = [factor*l for l in dom[space].shape]
newshape = [int(factor*l) for l in dom[space].shape]
op = ift.CentralZeroPadder(dom, newshape, space)
ift.extra.consistency_check(op, dtype, dtype)
......@@ -243,7 +244,7 @@ class Consistency_Tests(unittest.TestCase):
@expand([[ift.RGSpace((13, 52, 40)), (4, 6, 25), None],
[ift.RGSpace((128, 128)), (45, 48), 0],
[ift.RGSpace(13), (7,), None],
[(ift.HPSpace(3), ift.RGSpace((12, 24),distances=0.3)),
[(ift.HPSpace(3), ift.RGSpace((12, 24), distances=0.3)),
(12, 12), 1]])
def testRegridding(self, domain, shape, space):
op = ift.RegriddingOperator(domain, shape, space)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment