Commit cf2e08c2 authored by Martin Reinecke

Merge branch 'array_ops_on_speed' into 'NIFTy_5'

Speed up DOFDistributor and ExpTransform

Closes #44

See merge request ift/nifty-dev!66
parents 516a415f edc2ed5e
......@@ -25,7 +25,7 @@ from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.dof_space import DOFSpace
from ..field import Field
from ..utilities import infer_space
from ..utilities import infer_space, special_add_at
from .linear_operator import LinearOperator
......@@ -116,7 +116,7 @@ class DOFDistributor(LinearOperator):
arr = x.local_data
arr = arr.reshape(self._pshape)
oarr = np.zeros(self._hshape, dtype=x.dtype)
np.add.at(oarr, (slice(None), self._dofdex, slice(None)), arr)
oarr = special_add_at(oarr, 1, self._dofdex, arr)
if dobj.distaxis(x.val) in x.domain.axes[self._space]:
oarr = dobj.np_allreduce_sum(oarr).reshape(self._domain.shape)
res = Field.from_global_data(self._domain, oarr)
......
......@@ -27,13 +27,13 @@ from ..domains.power_space import PowerSpace
from ..domains.rg_space import RGSpace
from ..field import Field
from .linear_operator import LinearOperator
from .. import utilities
from ..utilities import infer_space, special_add_at
class ExpTransform(LinearOperator):
def __init__(self, target, dof, space=0):
self._target = DomainTuple.make(target)
self._space = utilities.infer_space(self._target, space)
self._space = infer_space(self._target, space)
tgt = self._target[self._space]
if not ((isinstance(tgt, RGSpace) and tgt.harmonic) or
isinstance(tgt, PowerSpace)):
......@@ -112,8 +112,8 @@ class ExpTransform(LinearOperator):
shp = list(x.shape)
shp[d] = self._tgt(mode).shape[d]
xnew = np.zeros(shp, dtype=x.dtype)
np.add.at(xnew, idx + (self._bindex[d-d0],), x * (1.-wgt))
np.add.at(xnew, idx + (self._bindex[d-d0]+1,), x * wgt)
xnew = special_add_at(xnew, d, self._bindex[d-d0], x*(1.-wgt))
xnew = special_add_at(xnew, d, self._bindex[d-d0]+1, x*wgt)
else: # TIMES
xnew = x[idx + (self._bindex[d-d0],)] * (1.-wgt)
xnew += x[idx + (self._bindex[d-d0]+1,)] * wgt
......
......@@ -30,7 +30,7 @@ from .compat import *
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMetaBase", "fft_prep", "hartley", "my_fftn_r2c",
"my_fftn", "my_sum", "my_lincomb_simple", "my_lincomb",
"my_product", "frozendict"]
"my_product", "frozendict", "special_add_at"]
def my_sum(terms):
......@@ -81,7 +81,7 @@ def get_slice_list(shape, axes):
if axes:
if not all(axis < len(shape) for axis in axes):
raise ValueError("axes(axis) does not match shape.")
axes_select = [0 if x in axes else 1 for x, y in enumerate(shape)]
axes_select = [0 if x in axes else 1 for x in range(len(shape))]
axes_iterables = \
[list(range(y)) for x, y in enumerate(shape) if x not in axes]
for index in product(*axes_iterables):
......@@ -334,3 +334,25 @@ class frozendict(collections.Mapping):
h ^= hash((key, value))
self._hash = h
return self._hash
def special_add_at(a, axis, index, b):
    """Scatter-add `b` into `a` along `axis`; a faster `np.add.at`.

    Computes the same result as::

        np.add.at(a, (slice(None),)*axis + (index,), b)

    i.e. for every position `k` along `axis` of `b`, the slice `b[..., k,
    ...]` is accumulated into `a[..., index[k], ...]`; repeated entries in
    `index` add up.

    Parameters
    ----------
    a : numpy.ndarray
        Array that receives the sums.
    axis : int
        Axis of `a` (and of `b`) along which `index` applies. Negative
        values count from the last axis.
    index : 1-D sequence of non-negative ints
        Target position along `axis` for each entry of `b` along `axis`.
    b : numpy.ndarray
        Values to add. Must have the same dtype as `a`, and the same
        extents as `a` on all axes except `axis`.

    Returns
    -------
    numpy.ndarray
        Array of shape `a.shape` holding the result. `a` itself is updated
        only if NumPy can reshape it without copying, so callers must use
        the return value.

    Raises
    ------
    TypeError
        If `a` and `b` have different dtypes.
    """
    if a.dtype != b.dtype:
        raise TypeError("data type mismatch")
    if axis < 0:
        # shape[:axis] / shape[axis+1:] only split correctly for axis >= 0.
        axis += a.ndim
    n_lead = int(np.prod(a.shape[:axis]))
    n_trail = int(np.prod(a.shape[axis+1:]))
    if n_lead == 0 or n_trail == 0:
        # Nothing to accumulate, and reshape(..., -1, ...) would be
        # ambiguous for a zero-size array.
        return a.reshape(a.shape)
    # Collapse to (leading, target axis, trailing) so that one pair of
    # loops covers every 1-D line along the target axis.
    a2 = a.reshape([n_lead, -1, n_trail])
    b2 = b.reshape([n_lead, -1, n_trail])
    if not np.issubdtype(a.dtype, np.inexact):
        # bincount always yields float64, which cannot be added in place to
        # integer/bool/object data; use the exact generic routine instead.
        np.add.at(a2, (slice(None), index, slice(None)), b2)
        return a2.reshape(a.shape)
    is_complex = np.issubdtype(a.dtype, np.complexfloating)
    if is_complex:
        # bincount only takes real weights: view each complex number as a
        # (real, imag) pair, which doubles the trailing extent.
        real_dtype = a.real.dtype
        a2 = a2.view(real_dtype)
        b2 = b2.view(real_dtype)
        n_trail *= 2
    nbins = a2.shape[1]
    for i1 in range(n_lead):
        for i3 in range(n_trail):
            a2[i1, :, i3] += np.bincount(index, b2[i1, :, i3],
                                         minlength=nbins)
    if is_complex:
        a2 = a2.view(a.dtype)
    return a2.reshape(a.shape)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment