Commit 928cd935 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweaks for add.at

parent e7ff7772
......@@ -25,7 +25,7 @@ from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.dof_space import DOFSpace
from ..field import Field
from ..utilities import infer_space
from ..utilities import infer_space, special_add_at
from .linear_operator import LinearOperator
......@@ -116,7 +116,7 @@ class DOFDistributor(LinearOperator):
arr = x.local_data
arr = arr.reshape(self._pshape)
oarr = np.zeros(self._hshape, dtype=x.dtype)
np.add.at(oarr, (slice(None), self._dofdex, slice(None)), arr)
oarr = special_add_at(arr, 1, self._dofdex, oarr)
if dobj.distaxis(x.val) in x.domain.axes[self._space]:
oarr = dobj.np_allreduce_sum(oarr).reshape(self._domain.shape)
res = Field.from_global_data(self._domain, oarr)
......
......@@ -112,8 +112,10 @@ class ExpTransform(LinearOperator):
shp = list(x.shape)
shp[d] = self._tgt(mode).shape[d]
xnew = np.zeros(shp, dtype=x.dtype)
np.add.at(xnew, idx + (self._bindex[d-d0],), x * (1.-wgt))
np.add.at(xnew, idx + (self._bindex[d-d0]+1,), x * wgt)
xnew = utilities.special_add_at(x*(1.-wgt), d,
self._bindex[d-d0], xnew)
xnew = utilities.special_add_at(x*wgt, d,
self._bindex[d-d0]+1, xnew)
else: # TIMES
xnew = x[idx + (self._bindex[d-d0],)] * (1.-wgt)
xnew += x[idx + (self._bindex[d-d0]+1,)] * wgt
......
......@@ -30,7 +30,7 @@ from .compat import *
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMetaBase", "fft_prep", "hartley", "my_fftn_r2c",
"my_fftn", "my_sum", "my_lincomb_simple", "my_lincomb",
"my_product", "frozendict"]
"my_product", "frozendict", "special_add_at"]
def my_sum(terms):
......@@ -334,3 +334,24 @@ class frozendict(collections.Mapping):
h ^= hash((key, value))
self._hash = h
return self._hash
def special_add_at(inp, axis, index, out):
sz1=int(np.prod(inp.shape[:axis]))
sz3=int(np.prod(inp.shape[axis+1:]))
print(inp.shape, sz1, sz3)
inp2 = inp.reshape([sz1,-1,sz3])
out2 = out.reshape([sz1,-1,sz3])
if not np.issubdtype(inp.dtype, np.complexfloating):
for i1 in range(sz1):
for i3 in range(sz3):
out2[i1,:,i3] += np.bincount(index, inp2[i1,:,i3],
minlength=out2.shape[1])
else:
for i1 in range(sz1):
for i3 in range(sz3):
out2[i1,:,i3].real += np.bincount(index, inp2[i1,:,i3].real,
minlength=out2.shape[1])
out2[i1,:,i3].imag += np.bincount(index, inp2[i1,:,i3].imag,
minlength=out2.shape[1])
return out2.reshape(out.shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment