Commit 2c13c8f0 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweaks

parent f9c22026
......@@ -43,7 +43,7 @@ class DomainDistributor(LinearOperator):
shp = []
for i, tgt in enumerate(self._target):
tmp = tgt.shape if i > 0 else tgt.local_shape
shp += tmp if i in self._spaces else(1,)*len(tgt.shape)
shp += tmp if i in self._spaces else (1,)*len(tgt.shape)
ldat = np.broadcast_to(ldat.reshape(shp), self._target.local_shape)
return Field.from_local_data(self._target, ldat)
else:
......
......@@ -108,6 +108,7 @@ class ExpTransform(LinearOperator):
v = x.val
ndim = len(self.target.shape)
curshp = list(self._dom(mode).shape)
tgtshp = self._tgt(mode).shape
d0 = self._target.axes[self._space][0]
for d in self._target.axes[self._space]:
idx = (slice(None),) * d
......@@ -117,7 +118,7 @@ class ExpTransform(LinearOperator):
if mode == self.ADJOINT_TIMES:
shp = list(x.shape)
shp[d] = self._tgt(mode).shape[d]
shp[d] = tgtshp[d]
xnew = np.zeros(shp, dtype=x.dtype)
xnew = special_add_at(xnew, d, self._bindex[d-d0], x*(1.-wgt))
xnew = special_add_at(xnew, d, self._bindex[d-d0]+1, x*wgt)
......@@ -125,6 +126,6 @@ class ExpTransform(LinearOperator):
xnew = x[idx + (self._bindex[d-d0],)] * (1.-wgt)
xnew += x[idx + (self._bindex[d-d0]+1,)] * wgt
curshp[d] = self._tgt(mode).shape[d]
curshp[d] = xnew.shape[d]
v = dobj.from_local_data(curshp, xnew, distaxis=dobj.distaxis(v))
return Field(self._tgt(mode), dobj.ensure_default_distributed(v))
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
import numpy as np
......@@ -31,25 +49,22 @@ class FieldZeroPadder(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
x = x.val
dax = dobj.distaxis(x)
shp_in = x.shape
shp_out = self._tgt(mode).shape
axbefore = self._target.axes[self._space][0]
axes = self._target.axes[self._space]
if dax in axes:
x = dobj.redistribute(x, nodist=axes)
curax = dobj.distaxis(x)
if mode == self.ADJOINT_TIMES:
newarr = np.empty(dobj.local_shape(shp_out, curax), dtype=x.dtype)
sl = tuple(slice(0, shp_out[axis]) for axis in axes)
newarr[()] = dobj.local_data(x)[(slice(None),)*axbefore + sl]
else:
newarr = np.zeros(dobj.local_shape(shp_out, curax), dtype=x.dtype)
sl = tuple(slice(0, shp_in[axis]) for axis in axes)
newarr[(slice(None),)*axbefore + sl] = dobj.local_data(x)
newarr = dobj.from_local_data(shp_out, newarr, distaxis=curax)
if dax in axes:
newarr = dobj.redistribute(newarr, dist=dax)
return Field(self._tgt(mode), val=newarr)
v = x.val
curshp = list(self._dom(mode).shape)
tgtshp = self._tgt(mode).shape
for d in self._target.axes[self._space]:
idx = (slice(None),) * d
v, x = dobj.ensure_not_distributed(v, (d,))
if mode == self.TIMES:
shp = list(x.shape)
shp[d] = tgtshp[d]
xnew = np.zeros(shp, dtype=x.dtype)
xnew[idx + (slice(0, x.shape[d]),)] = x
else: # ADJOINT_TIMES
xnew = x[idx + (slice(0, tgtshp[d]),)]
curshp[d] = xnew.shape[d]
v = dobj.from_local_data(curshp, xnew, distaxis=dobj.distaxis(v))
return Field(self._tgt(mode), dobj.ensure_default_distributed(v))
......@@ -53,7 +53,6 @@ class RegriddingOperator(LinearOperator):
self._capability = self.TIMES | self.ADJOINT_TIMES
ndim = len(new_shape)
bindistances = np.empty(ndim)
self._bindex = [None] * ndim
self._frac = [None] * ndim
for d in range(ndim):
......@@ -66,6 +65,7 @@ class RegriddingOperator(LinearOperator):
v = x.val
ndim = len(self.target.shape)
curshp = list(self._dom(mode).shape)
tgtshp = self._tgt(mode).shape
d0 = self._target.axes[self._space][0]
for d in self._target.axes[self._space]:
idx = (slice(None),) * d
......@@ -75,7 +75,7 @@ class RegriddingOperator(LinearOperator):
if mode == self.ADJOINT_TIMES:
shp = list(x.shape)
shp[d] = self._tgt(mode).shape[d]
shp[d] = tgtshp[d]
xnew = np.zeros(shp, dtype=x.dtype)
xnew = special_add_at(xnew, d, self._bindex[d-d0], x*(1.-wgt))
xnew = special_add_at(xnew, d, self._bindex[d-d0]+1, x*wgt)
......@@ -83,6 +83,6 @@ class RegriddingOperator(LinearOperator):
xnew = x[idx + (self._bindex[d-d0],)] * (1.-wgt)
xnew += x[idx + (self._bindex[d-d0]+1,)] * wgt
curshp[d] = self._tgt(mode).shape[d]
curshp[d] = xnew.shape[d]
v = dobj.from_local_data(curshp, xnew, distaxis=dobj.distaxis(v))
return Field(self._tgt(mode), dobj.ensure_default_distributed(v))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment