Commit 37e77ddb authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'operator_documentation' of gitlab.mpcdf.mpg.de:ift/nifty-dev...

Merge branch 'operator_documentation' of gitlab.mpcdf.mpg.de:ift/nifty-dev into operator_documentation
parents c1853600 f156b5fe
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
import numpy as np
......@@ -15,11 +33,28 @@ from .. import dobj
# MR FIXME: for even axis lengths, we probably should split the value at the
# highest frequency.
class CentralZeroPadder(LinearOperator):
"""Operator that enlarges a fields domain by adding zeros from the middle.
Parameters
---------
domain: Domain, tuple of Domains or DomainTuple
The domain of the data that is input by "times" and output by "adjoint_times"
new_shape: tuple
Shape of the target domain.
space: int, optional
The index of the subdomain on which the operator should act
If None, it is set to 0 if `domain` contains exactly one space.
`domain[space]` must be an RGSpace.
"""
def __init__(self, domain, new_shape, space=0):
self._domain = DomainTuple.make(domain)
self._space = utilities.infer_space(self._domain, space)
dom = self._domain[self._space]
# verify domains
if not isinstance(dom, RGSpace):
raise TypeError("RGSpace required")
if dom.harmonic:
......@@ -29,12 +64,15 @@ class CentralZeroPadder(LinearOperator):
if any([a < b for a, b in zip(new_shape, dom.shape)]):
raise ValueError("New shape must be larger than old shape")
# make target space
tgt = RGSpace(new_shape, dom.distances)
self._target = list(self._domain)
self._target[self._space] = tgt
self._target = DomainTuple.make(self._target)
self._capability = self.TIMES | self.ADJOINT_TIMES
# define the axes along which the input filed is sliced
slicer = []
axes = self._target.axes[self._space]
for i in range(len(self._domain.shape)):
......@@ -65,10 +103,14 @@ class CentralZeroPadder(LinearOperator):
x = dobj.local_data(x)
if mode == self.TIMES:
# slice along each axis and copy the data to an
# array of zeros which has the shape of the target domain
y = np.zeros(dobj.local_shape(shp_out, curax), dtype=x.dtype)
for i in self.slicer:
y[i] = x[i]
else:
# slice along each axis and copy the data to an array of zeros
# which has the shape of the input domain to remove excess zeros
y = np.empty(dobj.local_shape(shp_out, curax), dtype=x.dtype)
for i in self.slicer:
y[i] = x[i]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment