Commit a2af7344 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cosmetics

parent 625370a6
...@@ -339,10 +339,10 @@ class frozendict(collections.Mapping): ...@@ -339,10 +339,10 @@ class frozendict(collections.Mapping):
def special_add_at(a, axis, index, b): def special_add_at(a, axis, index, b):
if a.dtype != b.dtype: if a.dtype != b.dtype:
raise TypeError("data type mismatch") raise TypeError("data type mismatch")
sz1=int(np.prod(a.shape[:axis])) sz1 = int(np.prod(a.shape[:axis]))
sz3=int(np.prod(a.shape[axis+1:])) sz3 = int(np.prod(a.shape[axis+1:]))
a2 = a.reshape([sz1,-1,sz3]) a2 = a.reshape([sz1, -1, sz3])
b2 = b.reshape([sz1,-1,sz3]) b2 = b.reshape([sz1, -1, sz3])
if np.issubdtype(a.dtype, np.complexfloating): if np.issubdtype(a.dtype, np.complexfloating):
dt2 = a.real.dtype dt2 = a.real.dtype
a2 = a2.view(dt2) a2 = a2.view(dt2)
...@@ -350,8 +350,8 @@ def special_add_at(a, axis, index, b): ...@@ -350,8 +350,8 @@ def special_add_at(a, axis, index, b):
sz3 *= 2 sz3 *= 2
for i1 in range(sz1): for i1 in range(sz1):
for i3 in range(sz3): for i3 in range(sz3):
a2[i1,:,i3] += np.bincount(index, b2[i1,:,i3], a2[i1, :, i3] += np.bincount(index, b2[i1, :, i3],
minlength=a2.shape[1]) minlength=a2.shape[1])
if np.issubdtype(a.dtype, np.complexfloating): if np.issubdtype(a.dtype, np.complexfloating):
a2 = a2.view(a.dtype) a2 = a2.view(a.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment