# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import itertools

import numpy as np

from .. import dobj, utilities
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..field import Field
from .linear_operator import LinearOperator


# MR FIXME: for even axis lengths, we probably should split the value at the
# highest frequency.
class CentralZeroPadder(LinearOperator):
    """Operator that enlarges a field's domain by adding zeros in the middle.

    Along every axis of the selected RGSpace (original length ``n``), the
    first ``(n+1)//2`` input entries are copied to the start of the output
    and the last ``n//2`` input entries to the end of the output; everything
    in between is zero. The adjoint extracts exactly those entries again and
    discards the padded middle.

    Parameters
    ----------
    domain : Domain, tuple of Domains or DomainTuple
        The domain of the data that is input by "times" and output by
        "adjoint_times".
    new_shape : tuple
        Shape of the target domain. It must have as many entries as
        ``domain[space].shape``, and no entry may be smaller than the
        corresponding entry there.
    space : int, optional
        The index of the subdomain on which the operator should act.
        If None, it is set to 0 if `domain` contains exactly one space.
        `domain[space]` must be an RGSpace.

    Raises
    ------
    TypeError
        If ``domain[space]`` is not an RGSpace.
    ValueError
        If `new_shape` has the wrong number of entries, or is smaller than
        the original shape along some axis.
    """

    def __init__(self, domain, new_shape, space=0):
        self._domain = DomainTuple.make(domain)
        self._space = utilities.infer_space(self._domain, space)
        dom = self._domain[self._space]

        # verify domains
        if not isinstance(dom, RGSpace):
            raise TypeError("RGSpace required")
        if len(new_shape) != len(dom.shape):
            raise ValueError("Shape mismatch")
        if any(a < b for a, b in zip(new_shape, dom.shape)):
            raise ValueError("New shape must be larger than old shape")

        # make target space: same distances and harmonicity, larger grid
        tgt = RGSpace(new_shape, dom.distances, harmonic=dom.harmonic)
        target = list(self._domain)
        target[self._space] = tgt
        self._target = DomainTuple.make(target)
        self._capability = self.TIMES | self.ADJOINT_TIMES

        # Build the index tuples used to copy data between the small and the
        # large array. Each padded axis of length n contributes two regions:
        # the leading (n+1)//2 entries, and the trailing n//2 entries. The
        # trailing region is addressed from the end with a negative step, so
        # the same slice selects matching entries in both arrays. Axes that
        # are not padded are always taken whole. Every combination of regions
        # across all axes is one block to copy.
        axes = self._target.axes[self._space]
        per_axis = []
        for i, n in enumerate(self._domain.shape):
            if i in axes:
                per_axis.append((slice(0, (n+1)//2),
                                 slice(-1, -1-(n//2), -1)))
            else:
                per_axis.append((slice(None),))
        self.slicer = tuple(itertools.product(*per_axis))

    def apply(self, x, mode):
        self._check_input(x, mode)
        shp_out = self._tgt(mode).shape
        # Make sure the padded axes are not split across tasks (as the
        # helper's name suggests), so that they can be sliced on local data.
        v, local_in = dobj.ensure_not_distributed(
            x.val, self._target.axes[self._space])
        curax = dobj.distaxis(v)
        local_shp = dobj.local_shape(shp_out, curax)

        if mode == self.TIMES:
            # The output is larger than the input; the entries not covered by
            # the slicer regions (the padded middle) must stay zero.
            y = np.zeros(local_shp, dtype=local_in.dtype)
        else:
            # The output has the original shape. Per axis the regions cover
            # (n+1)//2 + n//2 == n entries, i.e. all of it, so every entry is
            # overwritten below and no zero-initialization is needed.
            y = np.empty(local_shp, dtype=local_in.dtype)

        # The same index tuples address corresponding entries in the small
        # and in the large array, so one copy loop serves both directions.
        for sl in self.slicer:
            y[sl] = local_in[sl]

        v = dobj.from_local_data(shp_out, y, distaxis=curax)
        return Field(self._tgt(mode), dobj.ensure_default_distributed(v))