Commit edc2ed5e authored by Martin Reinecke's avatar Martin Reinecke
Browse files

even more cosmetics

parent a2af7344
......@@ -27,13 +27,13 @@ from import PowerSpace
from import RGSpace
from ..field import Field
from .linear_operator import LinearOperator
from .. import utilities
from ..utilities import infer_space, special_add_at
class ExpTransform(LinearOperator):
def __init__(self, target, dof, space=0):
self._target = DomainTuple.make(target)
self._space = utilities.infer_space(self._target, space)
self._space = infer_space(self._target, space)
tgt = self._target[self._space]
if not ((isinstance(tgt, RGSpace) and tgt.harmonic) or
isinstance(tgt, PowerSpace)):
......@@ -112,10 +112,8 @@ class ExpTransform(LinearOperator):
shp = list(x.shape)
shp[d] = self._tgt(mode).shape[d]
xnew = np.zeros(shp, dtype=x.dtype)
xnew = utilities.special_add_at(xnew, d,
self._bindex[d-d0], x*(1.-wgt))
xnew = utilities.special_add_at(xnew, d,
self._bindex[d-d0]+1, x*wgt)
xnew = special_add_at(xnew, d, self._bindex[d-d0], x*(1.-wgt))
xnew = special_add_at(xnew, d, self._bindex[d-d0]+1, x*wgt)
else: # TIMES
xnew = x[idx + (self._bindex[d-d0],)] * (1.-wgt)
xnew += x[idx + (self._bindex[d-d0]+1,)] * wgt
......@@ -81,7 +81,7 @@ def get_slice_list(shape, axes):
if axes:
if not all(axis < len(shape) for axis in axes):
raise ValueError("axes(axis) does not match shape.")
axes_select = [0 if x in axes else 1 for x, y in enumerate(shape)]
axes_select = [0 if x in axes else 1 for x in range(len(shape))]
axes_iterables = \
[list(range(y)) for x, y in enumerate(shape) if x not in axes]
for index in product(*axes_iterables):
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment