Commit 625370a6 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweaks

parent 558a3f53
......@@ -116,7 +116,7 @@ class DOFDistributor(LinearOperator):
arr = x.local_data
arr = arr.reshape(self._pshape)
oarr = np.zeros(self._hshape, dtype=x.dtype)
oarr = special_add_at(arr, 1, self._dofdex, oarr)
oarr = special_add_at(oarr, 1, self._dofdex, arr)
if dobj.distaxis(x.val) in x.domain.axes[self._space]:
oarr = dobj.np_allreduce_sum(oarr).reshape(self._domain.shape)
res = Field.from_global_data(self._domain, oarr)
......
......@@ -112,10 +112,10 @@ class ExpTransform(LinearOperator):
shp = list(x.shape)
shp[d] = self._tgt(mode).shape[d]
xnew = np.zeros(shp, dtype=x.dtype)
xnew = utilities.special_add_at(x*(1.-wgt), d,
self._bindex[d-d0], xnew)
xnew = utilities.special_add_at(x*wgt, d,
self._bindex[d-d0]+1, xnew)
xnew = utilities.special_add_at(xnew, d,
self._bindex[d-d0], x*(1.-wgt))
xnew = utilities.special_add_at(xnew, d,
self._bindex[d-d0]+1, x*wgt)
else: # TIMES
xnew = x[idx + (self._bindex[d-d0],)] * (1.-wgt)
xnew += x[idx + (self._bindex[d-d0]+1,)] * wgt
......
......@@ -336,21 +336,23 @@ class frozendict(collections.Mapping):
return self._hash
def special_add_at(inp, axis, index, out):
sz1=int(np.prod(inp.shape[:axis]))
sz3=int(np.prod(inp.shape[axis+1:]))
inp2 = inp.reshape([sz1,-1,sz3])
out2 = out.reshape([sz1,-1,sz3])
if not np.issubdtype(inp.dtype, np.complexfloating):
for i1 in range(sz1):
for i3 in range(sz3):
out2[i1,:,i3] += np.bincount(index, inp2[i1,:,i3],
minlength=out2.shape[1])
else:
for i1 in range(sz1):
for i3 in range(sz3):
out2[i1,:,i3].real += np.bincount(index, inp2[i1,:,i3].real,
minlength=out2.shape[1])
out2[i1,:,i3].imag += np.bincount(index, inp2[i1,:,i3].imag,
minlength=out2.shape[1])
return out2.reshape(out.shape)
def special_add_at(a, axis, index, b):
if a.dtype != b.dtype:
raise TypeError("data type mismatch")
sz1=int(np.prod(a.shape[:axis]))
sz3=int(np.prod(a.shape[axis+1:]))
a2 = a.reshape([sz1,-1,sz3])
b2 = b.reshape([sz1,-1,sz3])
if np.issubdtype(a.dtype, np.complexfloating):
dt2 = a.real.dtype
a2 = a2.view(dt2)
b2 = b2.view(dt2)
sz3 *= 2
for i1 in range(sz1):
for i3 in range(sz3):
a2[i1,:,i3] += np.bincount(index, b2[i1,:,i3],
minlength=a2.shape[1])
if np.issubdtype(a.dtype, np.complexfloating):
a2 = a2.view(a.dtype)
return a2.reshape(a.shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment