Commit 50ffbdcc authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'central_zero_padder' into 'NIFTy_5'

Central Zero Padder

See merge request ift/nifty-dev!61
parents cf2e08c2 ea595092
......@@ -26,6 +26,7 @@ from .models.model import Model
from .models.multi_model import MultiModel
from .models.variable import Variable
from .operators.central_zero_padder import CentralZeroPadder
from .operators.diagonal_operator import DiagonalOperator
from .operators.dof_distributor import DOFDistributor
from .operators.domain_distributor import DomainDistributor
import numpy as np
import itertools
from .. import utilities
from .linear_operator import LinearOperator
from ..domain_tuple import DomainTuple
from import RGSpace
from ..field import Field
from .. import dobj
# MR FIXME: for even axis lengths, we probably should split the value at the
# highest frequency.
class CentralZeroPadder(LinearOperator):
def __init__(self, domain, new_shape, space=0):
super(CentralZeroPadder, self).__init__()
self._domain = DomainTuple.make(domain)
self._space = utilities.infer_space(self._domain, space)
dom = self._domain[self._space]
if not isinstance(dom, RGSpace):
raise TypeError("RGSpace required")
if dom.harmonic:
raise TypeError("RGSpace must not be harmonic")
if len(new_shape) != len(dom.shape):
raise ValueError("Shape mismatch")
if any([a < b for a, b in zip(new_shape, dom.shape)]):
raise ValueError("New shape must be larger than old shape")
tgt = RGSpace(new_shape, dom.distances)
self._target = list(self._domain)
self._target[self._space] = tgt
self._target = DomainTuple.make(self._target)
slicer = []
axes = self._target.axes[self._space]
for i in range(len(self._domain.shape)):
if i in axes:
slicer_fw = slice(0, (self._domain.shape[i]+1)//2)
slicer_bw = slice(-1, -1-(self._domain.shape[i]//2), -1)
slicer.append([slicer_fw, slicer_bw])
self.slicer = list(itertools.product(*slicer))
for i in range(len(self.slicer)):
for j in range(len(self._domain.shape)):
if j not in axes:
tmp = list(self.slicer[i])
tmp.insert(j, slice(None))
self.slicer[i] = tmp
def domain(self):
return self._domain
def target(self):
return self._target
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
x = x.val
dax = dobj.distaxis(x)
shp_in = x.shape
shp_out = self._tgt(mode).shape
axes = self._target.axes[self._space]
if dax in axes:
x = dobj.redistribute(x, nodist=axes)
curax = dobj.distaxis(x)
x = dobj.local_data(x)
if mode == self.TIMES:
y = np.zeros(dobj.local_shape(shp_out, curax), dtype=x.dtype)
for i in self.slicer:
y[i] = x[i]
y = np.empty(dobj.local_shape(shp_out, curax), dtype=x.dtype)
for i in self.slicer:
y[i] = x[i]
y = dobj.from_local_data(shp_out, y, distaxis=curax)
if dax in axes:
y = dobj.redistribute(y, dist=dax)
return Field(self._tgt(mode), val=y)
......@@ -209,6 +209,14 @@ class Consistency_Tests(unittest.TestCase):
op = ift.FieldZeroPadder(dom, newshape, space)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([0, 2], [2, 2.7], [np.float64, np.complex128]))
def testZeroPadder2(self, space, factor, dtype):
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7, 12),
newshape = [factor*l for l in dom[space].shape]
op = ift.CentralZeroPadder(dom, newshape, space)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([(ift.RGSpace(10, harmonic=True), 4, 0),
(ift.RGSpace((24, 31), distances=(0.4, 2.34),
harmonic=True), (4, 3), 0),
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment