Commit 073ce5ce authored by Martin Reinecke's avatar Martin Reinecke
Browse files

fixes

parent 6108ab56
......@@ -91,10 +91,10 @@ class data_object(object):
if axis is None:
res = np.array(getattr(self._data, op)())
if (self._distaxis==-1):
return res[0]
res2 = np.empty(1,dtype=res.dtype)
return res[()]
res2 = np.empty((),dtype=res.dtype)
_comm.Allreduce(res,res2,mpiop)
return res2[0]
return res2[()]
if self._distaxis in axis:
res = getattr(self._data, op)(axis=axis)
......@@ -122,6 +122,10 @@ class data_object(object):
def sum(self, axis=None):
return self._contraction_helper("sum", MPI.SUM, axis)
def min(self, axis=None):
return self._contraction_helper("min", MPI.MIN, axis)
def max(self, axis=None):
return self._contraction_helper("max", MPI.MAX, axis)
# FIXME: to be improved!
def mean(self):
......@@ -239,10 +243,10 @@ def empty_like(a, dtype=None):
def vdot(a, b):
tmp = np.vdot(a._data, b._data)
res = np.empty(1,dtype=type(tmp))
tmp = np.array(np.vdot(a._data, b._data))
res = np.empty((),dtype=tmp.dtype)
_comm.Allreduce(tmp,res,MPI.SUM)
return res[0]
return res[()]
def _math_helper(x, function, out):
......@@ -364,7 +368,7 @@ def redistribute (arr, dist=None, nodist=None):
out = np.moveaxis(out, 0, arr._distaxis)
return from_global_data (out, distaxis=-1)
# real redistribution via Alltoallv
# temporary slow, but simple solution
# temporary slow, but simple solution for comparison purposes:
#return redistribute(redistribute(arr,dist=-1),dist=dist)
tmp = np.moveaxis(arr._data, (dist, arr._distaxis), (0, 1))
......
......@@ -48,7 +48,7 @@ if not special_hartley:
def _fill_array(tmp, res, axes):
if axes is None:
axes = range(a.ndim)
axes = range(tmp.ndim)
lastaxis = axes[-1]
ntmplast = tmp.shape[lastaxis]
slice1 = [slice(None)]*lastaxis + [slice(0, ntmplast)]
......@@ -60,7 +60,7 @@ def hartley(a, axes=None):
# Check if the axes provided are valid given the shape
if axes is not None and \
not all(axis < len(a.shape) for axis in axes):
raise ValueError("Provided axes does not match array shape")
raise ValueError("Provided axes do not match array shape")
if issubclass(a.dtype.type, np.complexfloating):
raise TypeError("Hartley tansform requires real-valued arrays.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment