Commit 56f93fd8 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

fixes

parent 498c13aa
Pipeline #21276 failed with stage
in 4 minutes
...@@ -18,6 +18,8 @@ def shareRange(nwork, nshares, myshare): ...@@ -18,6 +18,8 @@ def shareRange(nwork, nshares, myshare):
return lo,hi return lo,hi
def get_locshape(shape, distaxis): def get_locshape(shape, distaxis):
if len(shape)==0:
distaxis = -1
if distaxis==-1: if distaxis==-1:
return shape return shape
if distaxis<0 or distaxis>=len(shape): if distaxis<0 or distaxis>=len(shape):
...@@ -32,6 +34,8 @@ class data_object(object): ...@@ -32,6 +34,8 @@ class data_object(object):
def __init__(self, shape, data, distaxis): def __init__(self, shape, data, distaxis):
"""Must not be called directly by users""" """Must not be called directly by users"""
self._shape = tuple(shape) self._shape = tuple(shape)
if len(self._shape)==0:
distaxis = -1
self._distaxis = distaxis self._distaxis = distaxis
lshape = get_locshape(self._shape, self._distaxis) lshape = get_locshape(self._shape, self._distaxis)
self._data = data self._data = data
...@@ -238,7 +242,7 @@ def vdot(a, b): ...@@ -238,7 +242,7 @@ def vdot(a, b):
tmp = np.vdot(a._data.ravel(), b._data.ravel()) tmp = np.vdot(a._data.ravel(), b._data.ravel())
res = np.empty(1,dtype=type(tmp)) res = np.empty(1,dtype=type(tmp))
comm.Allreduce(tmp,res,MPI.SUM) comm.Allreduce(tmp,res,MPI.SUM)
return res return res[0]
def _math_helper(x, function, out): def _math_helper(x, function, out):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment