Commit fc9e9bde authored by Martin Reinecke

Merge branch 'regridding_operator' into 'NIFTy_5'

Regridding operator

See merge request ift/nifty-dev!81
parents e7f2378c b09688f3
@@ -34,6 +34,7 @@ from .operators.laplace_operator import LaplaceOperator
from .operators.linear_operator import LinearOperator
from .operators.mask_operator import MaskOperator
from .operators.qht_operator import QHTOperator
from .operators.regridding_operator import RegriddingOperator
from .operators.sampling_enabler import SamplingEnabler
from .operators.sandwich_operator import SandwichOperator
from .operators.scaling_operator import ScalingOperator
@@ -33,7 +33,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"np_allreduce_min", "np_allreduce_max",
"distaxis", "from_local_data", "from_global_data", "to_global_data",
"redistribute", "default_distaxis", "is_numpy",
"lock", "locked", "uniform_full", "transpose", "to_global_data_rw"]
"lock", "locked", "uniform_full", "transpose", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed"]
_comm = MPI.COMM_WORLD
ntask = _comm.Get_size()
@@ -540,3 +541,15 @@ def lock(arr):
def locked(arr):
return not arr._data.flags.writeable
def ensure_not_distributed(arr, axes):
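# If `arr` is currently distributed along one of `axes`, redistribute it so
# that its distribution axis lies outside `axes`; return the (possibly new)
# data object together with its local data block.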
if arr._distaxis in axes:
arr = redistribute(arr, nodist=axes)
return arr, arr._data
def ensure_default_distributed(arr):
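# Redistribute `arr` along the default axis (0) if it is distributed differently.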
if arr._distaxis != 0:
arr = redistribute(arr, dist=0)
return arr
@@ -32,7 +32,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"np_allreduce_min", "np_allreduce_max",
"distaxis", "from_local_data", "from_global_data", "to_global_data",
"redistribute", "default_distaxis", "is_numpy",
"lock", "locked", "uniform_full", "to_global_data_rw"]
"lock", "locked", "uniform_full", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed"]
ntask = 1
rank = 0
@@ -132,3 +133,11 @@ def locked(arr):
def uniform_full(shape, fill_value, dtype=None, distaxis=-1):
return np.broadcast_to(fill_value, shape)
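# Single-task backend (ntask == 1): data is never distributed, so both
# helpers below are essentially no-ops.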
def ensure_not_distributed(arr, axes):
return arr, arr
def ensure_default_distributed(arr):
return arr
@@ -93,14 +93,10 @@ class CentralZeroPadder(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
x = x.val
dax = dobj.distaxis(x)
v = x.val
shp_out = self._tgt(mode).shape
axes = self._target.axes[self._space]
if dax in axes:
x = dobj.redistribute(x, nodist=axes)
curax = dobj.distaxis(x)
x = dobj.local_data(x)
v, x = dobj.ensure_not_distributed(v, self._target.axes[self._space])
curax = dobj.distaxis(v)
if mode == self.TIMES:
# slice along each axis and copy the data to an
@@ -114,7 +110,5 @@ class CentralZeroPadder(LinearOperator):
y = np.empty(dobj.local_shape(shp_out, curax), dtype=x.dtype)
for i in self.slicer:
y[i] = x[i]
y = dobj.from_local_data(shp_out, y, distaxis=curax)
if dax in axes:
y = dobj.redistribute(y, dist=dax)
return Field(self._tgt(mode), val=y)
v = dobj.from_local_data(shp_out, y, distaxis=dobj.distaxis(v))
return Field(self._tgt(mode), dobj.ensure_default_distributed(v))
@@ -105,19 +105,15 @@ class ExpTransform(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
x = x.val
ax = dobj.distaxis(x)
v = x.val
ndim = len(self.target.shape)
curshp = list(self._dom(mode).shape)
d0 = self._target.axes[self._space][0]
for d in self._target.axes[self._space]:
idx = (slice(None,),) * d
idx = (slice(None),) * d
wgt = self._frac[d-d0].reshape((1,)*d + (-1,) + (1,)*(ndim-d-1))
if d == ax:
x = dobj.redistribute(x, nodist=(ax,))
curax = dobj.distaxis(x)
x = dobj.local_data(x)
v, x = dobj.ensure_not_distributed(v, (d,))
if mode == self.ADJOINT_TIMES:
shp = list(x.shape)
@@ -130,7 +126,5 @@ class ExpTransform(LinearOperator):
xnew += x[idx + (self._bindex[d-d0]+1,)] * wgt
curshp[d] = self._tgt(mode).shape[d]
x = dobj.from_local_data(curshp, xnew, distaxis=curax)
if d == ax:
x = dobj.redistribute(x, dist=ax)
return Field(self._tgt(mode), val=x)
v = dobj.from_local_data(curshp, xnew, distaxis=dobj.distaxis(v))
return Field(self._tgt(mode), dobj.ensure_default_distributed(v))
@@ -322,23 +322,19 @@ class SHTOperator(LinearOperator):
def _apply_spherical(self, x, mode):
axes = x.domain.axes[self._space]
axis = axes[0]
tval = x.val
if dobj.distaxis(tval) == axis:
tval = dobj.redistribute(tval, nodist=(axis,))
distaxis = dobj.distaxis(tval)
v = x.val
v, idat = dobj.ensure_not_distributed(v, (axis,))
distaxis = dobj.distaxis(v)
p2h = not x.domain[self._space].harmonic
tdom = self._tgt(mode)
func = self._slice_p2h if p2h else self._slice_h2p
idat = dobj.local_data(tval)
odat = np.empty(dobj.local_shape(tdom.shape, distaxis=distaxis),
dtype=x.dtype)
for slice in utilities.get_slice_list(idat.shape, axes):
odat[slice] = func(idat[slice])
odat = dobj.from_local_data(tdom.shape, odat, distaxis)
if distaxis != dobj.distaxis(x.val):
odat = dobj.redistribute(odat, dist=dobj.distaxis(x.val))
return Field(tdom, odat)
return Field(tdom, dobj.ensure_default_distributed(odat))
class HarmonicTransformOperator(LinearOperator):
@@ -78,10 +78,7 @@ class LaplaceOperator(EndomorphicOperator):
sl_r = prefix + (slice(1, None),) # "right" slice
dpos = self._dpos.reshape((1,)*axis + (nval-1,))
dposc = self._dposc.reshape((1,)*axis + (nval,))
locval = x.val
if axis == dobj.distaxis(locval):
locval = dobj.redistribute(locval, nodist=(axis,))
val = dobj.local_data(locval)
v, val = dobj.ensure_not_distributed(x.val, (axis,))
ret = np.empty_like(val)
if mode == self.TIMES:
deriv = (val[sl_r]-val[sl_l])/dpos # defined between points
@@ -99,7 +96,5 @@ class LaplaceOperator(EndomorphicOperator):
ret[sl_l] = deriv
ret[prefix + (-1,)] = 0.
ret[sl_r] -= deriv
ret = dobj.from_local_data(locval.shape, ret, dobj.distaxis(locval))
if dobj.distaxis(locval) != dobj.distaxis(x.val):
ret = dobj.redistribute(ret, dist=dobj.distaxis(x.val))
return Field(self.domain, val=ret)
ret = dobj.from_local_data(x.shape, ret, dobj.distaxis(v))
return Field(self.domain, dobj.ensure_default_distributed(ret))
@@ -64,16 +64,11 @@ class QHTOperator(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
dom = self._domain[self._space]
x = x.val * dom.scalar_dvol
v = x.val * dom.scalar_dvol
n = self._domain.axes[self._space]
rng = n if mode == self.TIMES else reversed(n)
ax = dobj.distaxis(x)
for i in rng:
sl = (slice(None),)*i + (slice(1, None),)
if i == ax:
x = dobj.redistribute(x, nodist=(ax,))
tmp = dobj.local_data(x)
v, tmp = dobj.ensure_not_distributed(v, (i,))
tmp[sl] = hartley(tmp[sl], axes=(i,))
if i == ax:
x = dobj.redistribute(x, dist=ax)
return Field(self._tgt(mode), val=x)
return Field(self._tgt(mode), dobj.ensure_default_distributed(v))
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
import numpy as np
from .. import dobj
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..field import Field
from ..utilities import infer_space, special_add_at
from .linear_operator import LinearOperator
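# Linear operator that resamples a field living on an RGSpace onto a grid with
# fewer (or equally many) points covering the same physical volume, using
# linear interpolation along every axis of the selected subdomain.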
class RegriddingOperator(LinearOperator):
def __init__(self, domain, new_shape, space=0):
self._domain = DomainTuple.make(domain)
self._space = infer_space(self._domain, space)
dom = self._domain[self._space]
if not isinstance(dom, RGSpace):
raise TypeError("RGSpace required")
if len(new_shape) != len(dom.shape):
raise ValueError("Shape mismatch: {} vs {}".format(new_shape, dom.shape))
if any([a > b for a, b in zip(new_shape, dom.shape)]):
raise ValueError("New shape must not be larger than old shape")
newdist = tuple(dom.distances[i]*dom.shape[i]/new_shape[i]
for i in range(len(dom.shape)))
tgt = RGSpace(new_shape, newdist)
self._target = list(self._domain)
self._target[self._space] = tgt
self._target = DomainTuple.make(self._target)
self._capability = self.TIMES | self.ADJOINT_TIMES
ndim = len(new_shape)
self._bindex = [None] * ndim
self._frac = [None] * ndim
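# For every point of the new grid along axis d, _bindex[d] stores the index
# of the old grid point immediately to its left (clipped so that index+1 is
# still valid) and _frac[d] the fractional distance to that point; apply()
# uses them as interpolation indices and weights.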
for d in range(ndim):
tmp = np.arange(new_shape[d])*(newdist[d]/dom.distances[d])
self._bindex[d] = np.minimum(dom.shape[d]-2, tmp.astype(int))
self._frac[d] = tmp-self._bindex[d]
def apply(self, x, mode):
self._check_input(x, mode)
v = x.val
ndim = len(self.target.shape)
curshp = list(self._dom(mode).shape)
d0 = self._target.axes[self._space][0]
for d in self._target.axes[self._space]:
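# Process one axis at a time: TIMES gathers from the two neighbouring old
# grid points with weights (1-frac, frac), ADJOINT_TIMES scatters back with
# the same weights.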
idx = (slice(None),) * d
wgt = self._frac[d-d0].reshape((1,)*d + (-1,) + (1,)*(ndim-d-1))
v, x = dobj.ensure_not_distributed(v, (d,))
if mode == self.ADJOINT_TIMES:
shp = list(x.shape)
shp[d] = self._tgt(mode).shape[d]
xnew = np.zeros(shp, dtype=x.dtype)
xnew = special_add_at(xnew, d, self._bindex[d-d0], x*(1.-wgt))
xnew = special_add_at(xnew, d, self._bindex[d-d0]+1, x*wgt)
else: # TIMES
xnew = x[idx + (self._bindex[d-d0],)] * (1.-wgt)
xnew += x[idx + (self._bindex[d-d0]+1,)] * wgt
curshp[d] = self._tgt(mode).shape[d]
v = dobj.from_local_data(curshp, xnew, distaxis=dobj.distaxis(v))
return Field(self._tgt(mode), dobj.ensure_default_distributed(v))
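A minimal usage sketch for the new operator (the grid sizes and distances below are illustrative only, not taken from the tests): build it on an RGSpace, choose a coarser target shape, and apply it to a field.

import nifty5 as ift

dom = ift.RGSpace((128, 128), distances=0.1)
op = ift.RegriddingOperator(dom, new_shape=(64, 64))
f = ift.from_random('normal', op.domain)
g = op(f)  # field on the coarser 64x64 grid covering the same volume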
@@ -37,14 +37,9 @@ class SymmetrizingOperator(EndomorphicOperator):
def apply(self, x, mode):
self._check_input(x, mode)
tmp = x.val.copy()
ax = dobj.distaxis(tmp)
v = x.val.copy()
for i in self._domain.axes[self._space]:
lead = (slice(None),)*i
if i == ax:
tmp = dobj.redistribute(tmp, nodist=(ax,))
tmp2 = dobj.local_data(tmp)
tmp2[lead+(slice(1, None),)] -= tmp2[lead+(slice(None, 0, -1),)]
if i == ax:
tmp = dobj.redistribute(tmp, dist=ax)
return Field(self.target, val=tmp)
v, loc = dobj.ensure_not_distributed(v, (i,))
loc[lead+(slice(1, None),)] -= loc[lead+(slice(None, 0, -1),)]
return Field(self.target, dobj.ensure_default_distributed(v))
@@ -240,3 +240,12 @@ class Consistency_Tests(unittest.TestCase):
tgt = ift.DomainTuple.make(args[0])
op = ift.QHTOperator(tgt, args[1])
ift.extra.consistency_check(op, dtype, dtype)
@expand([[ift.RGSpace((13, 52, 40)), (4, 6, 25), None],
[ift.RGSpace((128, 128)), (45, 48), 0],
[ift.RGSpace(13), (7,), None],
[(ift.HPSpace(3), ift.RGSpace((12, 24), distances=0.3)),
(12, 12), 1]])
def testRegridding(self, domain, shape, space):
op = ift.RegriddingOperator(domain, shape, space)
ift.extra.consistency_check(op)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import unittest
from itertools import product
from test.common import expand
from numpy.testing import assert_allclose
import nifty5 as ift
class Regridding_Tests(unittest.TestCase):
@expand(
product([
ift.RGSpace(8, distances=12.9),
ift.RGSpace(59, distances=.24, harmonic=True),
ift.RGSpace([12, 3])
]))
def test_value(self, s):
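# Regridding onto a grid of identical shape must reproduce the field.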
Regrid = ift.RegriddingOperator(s, s.shape)
f = ift.from_random('normal', Regrid.domain)
assert_allclose(f.to_global_data(), Regrid(f).to_global_data())