Skip to content
Snippets Groups Projects
Commit 365ebcbe authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge remote-tracking branch 'origin/NIFTy_5' into spectral_plotting

parents 668a1b80 09f38780
No related branches found
No related tags found
No related merge requests found
......@@ -19,7 +19,6 @@ from .field import Field
from .multi_field import MultiField
from .operators.operator import Operator
from .operators.central_zero_padder import CentralZeroPadder
from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import DOFDistributor, PowerDistributor
from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter
......
......@@ -21,6 +21,7 @@ def _logger_init():
from . import dobj
res = logging.getLogger('NIFTy5')
res.setLevel(logging.DEBUG)
res.propagate = False
if dobj.rank == 0:
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import itertools
import numpy as np
from .. import dobj, utilities
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..field import Field
from .linear_operator import LinearOperator
# MR FIXME: for even axis lengths, we probably should split the value at the
# highest frequency.
class CentralZeroPadder(LinearOperator):
"""Operator that enlarges a fields domain by adding zeros from the middle.
Parameters
---------
domain: Domain, tuple of Domains or DomainTuple
The domain of the data that is input by "times" and output by
"adjoint_times"
new_shape: tuple
Shape of the target domain.
space: int, optional
The index of the subdomain on which the operator should act
If None, it is set to 0 if `domain` contains exactly one space.
`domain[space]` must be an RGSpace.
"""
def __init__(self, domain, new_shape, space=0):
self._domain = DomainTuple.make(domain)
self._space = utilities.infer_space(self._domain, space)
dom = self._domain[self._space]
# verify domains
if not isinstance(dom, RGSpace):
raise TypeError("RGSpace required")
if len(new_shape) != len(dom.shape):
raise ValueError("Shape mismatch")
if any([a < b for a, b in zip(new_shape, dom.shape)]):
raise ValueError("New shape must be larger than old shape")
# make target space
tgt = RGSpace(new_shape, dom.distances, harmonic=dom.harmonic)
self._target = list(self._domain)
self._target[self._space] = tgt
self._target = DomainTuple.make(self._target)
self._capability = self.TIMES | self.ADJOINT_TIMES
# define the axes along which the input field is sliced
slicer = []
axes = self._target.axes[self._space]
for i in range(len(self._domain.shape)):
if i in axes:
slicer_fw = slice(0, (self._domain.shape[i]+1)//2)
slicer_bw = slice(-1, -1-(self._domain.shape[i]//2), -1)
slicer.append((slicer_fw, slicer_bw))
self.slicer = list(itertools.product(*slicer))
for i in range(len(self.slicer)):
for j in range(len(self._domain.shape)):
if j not in axes:
tmp = list(self.slicer[i])
tmp.insert(j, slice(None))
self.slicer[i] = tuple(tmp)
self.slicer = tuple(self.slicer)
def apply(self, x, mode):
self._check_input(x, mode)
v = x.val
shp_out = self._tgt(mode).shape
v, x = dobj.ensure_not_distributed(v, self._target.axes[self._space])
curax = dobj.distaxis(v)
if mode == self.TIMES:
# slice along each axis and copy the data to an
# array of zeros which has the shape of the target domain
y = np.zeros(dobj.local_shape(shp_out, curax), dtype=x.dtype)
for i in self.slicer:
y[i] = x[i]
else:
# slice along each axis and copy the data to an array of zeros
# which has the shape of the input domain to remove excess zeros
y = np.empty(dobj.local_shape(shp_out, curax), dtype=x.dtype)
for i in self.slicer:
y[i] = x[i]
v = dobj.from_local_data(shp_out, y, distaxis=dobj.distaxis(v))
return Field(self._tgt(mode), dobj.ensure_default_distributed(v))
......@@ -216,16 +216,6 @@ def testZeroPadder(space, factor, dtype, central):
ift.extra.consistency_check(op, dtype, dtype)
@pmp('space', [0, 2])
@pmp('factor', [2, 2.7])
def testZeroPadder2(space, factor, dtype):
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7, 12),
ift.HPSpace(4))
newshape = [int(factor*l) for l in dom[space].shape]
op = ift.CentralZeroPadder(dom, newshape, space)
ift.extra.consistency_check(op, dtype, dtype)
@pmp('args',
[(ift.RGSpace(10, harmonic=True), 4, 0), (ift.RGSpace(
(24, 31), distances=(0.4, 2.34), harmonic=True), (4, 3), 0),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment