Commit 498c13aa authored by Martin Reinecke's avatar Martin Reinecke
Browse files

improvements

parent 10347246
Pipeline #21273 failed with stage
in 3 minutes and 57 seconds
......@@ -31,7 +31,7 @@ def local_shape(shape, distaxis):
class data_object(object):
def __init__(self, shape, data, distaxis):
"""Must not be called directly by users"""
self._shape = shape
self._shape = tuple(shape)
self._distaxis = distaxis
lshape = get_locshape(self._shape, self._distaxis)
self._data = data
......@@ -89,17 +89,30 @@ class data_object(object):
if axis is None:
res = np.array(getattr(self._data, op)())
if (self._distaxis==-1):
return res
return res[0]
res2 = np.empty(1,dtype=res.dtype)
comm.Allreduce(res,res2,mpiop)
return res2[0]
if self._distaxis in axis:
pass# reduce globally, redistribute the result along axis 0(?)
res = getattr(self._data, op)(axis=axis)
res2 = np.empty_like(res)
comm.Allreduce(res,res2,mpiop)
return from_global_data(res2, distaxis=0)
else:
# perform the contraction on the local data
data = getattr(self._data, op)(axis=axis)
#shp =
res = getattr(self._data, op)(axis=axis)
if self._distaxis == -1:
return from_global_data(res,distaxis=0)
shp = list(res.shape)
shift=0
for ax in axis:
if ax<self._distaxis:
shift+=1
print (axis,self._distaxis,shift)
shp[self._distaxis-shift] = self.shape[self._distaxis]
print (self.shape, shp)
return from_local_data(shp, res, self._distaxis-shift)
# check if the result is scalar or if a result_field must be constr.
if np.isscalar(data):
......
......@@ -240,12 +240,16 @@ def from_local_data (shape, arr, distaxis):
return data_object(arr)
def from_global_data (arr, distaxis):
def from_global_data (arr, distaxis=-1):
if distaxis!=-1:
raise NotImplementedError
return data_object(arr)
def to_global_data (arr):
return arr._data
def redistribute (arr, dist=None, nodist=None):
if dist is not None and dist!=-1:
raise NotImplementedError
......
......@@ -47,12 +47,16 @@ def from_local_data (shape, arr, distaxis):
return arr
def from_global_data (arr, distaxis):
def from_global_data (arr, distaxis=-1):
if distaxis!=-1:
raise NotImplementedError
return arr
def to_global_data (arr):
return arr
def redistribute (arr, dist=None, nodist=None):
if dist is not None and dist!=-1:
raise NotImplementedError
......
......@@ -58,7 +58,7 @@ class SmoothingOperator_Tests(unittest.TestCase):
op = ift.FFTSmoothingOperator(space, sigma=sigma)
fld = np.zeros(space.shape, dtype=np.float64)
fld[0] = 1.
rand1 = ift.Field(space, ift.dobj.from_global_data(fld, distaxis=-1))
rand1 = ift.Field(space, ift.dobj.from_global_data(fld))
tt1 = op.times(rand1)
assert_allclose(1, tt1.sum())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment