diff --git a/nifty/data_objects/distributed_do.py b/nifty/data_objects/distributed_do.py index 7ab10f45c3b7688e1b61a78a1bee2a1e612b2a96..6af7b8397239b4e85fe1a65da104ccc486799160 100644 --- a/nifty/data_objects/distributed_do.py +++ b/nifty/data_objects/distributed_do.py @@ -31,7 +31,7 @@ def local_shape(shape, distaxis): class data_object(object): def __init__(self, shape, data, distaxis): """Must not be called directly by users""" - self._shape = shape + self._shape = tuple(shape) self._distaxis = distaxis lshape = get_locshape(self._shape, self._distaxis) self._data = data @@ -89,17 +89,30 @@ class data_object(object): if axis is None: res = np.array(getattr(self._data, op)()) if (self._distaxis==-1): - return res + return res[0] res2 = np.empty(1,dtype=res.dtype) comm.Allreduce(res,res2,mpiop) return res2[0] if self._distaxis in axis: - pass# reduce globally, redistribute the result along axis 0(?) + res = getattr(self._data, op)(axis=axis) + res2 = np.empty_like(res) + comm.Allreduce(res,res2,mpiop) + return from_global_data(res2, distaxis=0) else: # perform the contraction on the local data - data = getattr(self._data, op)(axis=axis) - #shp = + res = getattr(self._data, op)(axis=axis) + if self._distaxis == -1: + return from_global_data(res,distaxis=0) + shp = list(res.shape) + shift=0 + for ax in axis: + if ax<self._distaxis: + shift+=1 + print (axis,self._distaxis,shift) + shp[self._distaxis-shift] = self.shape[self._distaxis] + print (self.shape, shp) + return from_local_data(shp, res, self._distaxis-shift) # check if the result is scalar or if a result_field must be constr. if np.isscalar(data): diff --git a/nifty/data_objects/my_own_do.py b/nifty/data_objects/my_own_do.py index 0fbf4a146184773e3334ae9af1f1a62d36fd0fe2..bbfc9a2f097ba240719e594c1fda271893d0942a 100644 --- a/nifty/data_objects/my_own_do.py +++ b/nifty/data_objects/my_own_do.py @@ -240,12 +240,16 @@ def from_local_data (shape, arr, distaxis): return data_object(arr) -def from_global_data (arr, distaxis): +def from_global_data (arr, distaxis=-1): if distaxis!=-1: raise NotImplementedError return data_object(arr) +def to_global_data (arr): + return arr._data + + def redistribute (arr, dist=None, nodist=None): if dist is not None and dist!=-1: raise NotImplementedError diff --git a/nifty/data_objects/numpy_do.py b/nifty/data_objects/numpy_do.py index 558ffb3731ca4941768198f1ca16ba6e3db6b7e2..c465d43ba9ca355f243d74261721865e69397d6c 100644 --- a/nifty/data_objects/numpy_do.py +++ b/nifty/data_objects/numpy_do.py @@ -47,12 +47,16 @@ def from_local_data (shape, arr, distaxis): return arr -def from_global_data (arr, distaxis): +def from_global_data (arr, distaxis=-1): if distaxis!=-1: raise NotImplementedError return arr +def to_global_data (arr): + return arr + + def redistribute (arr, dist=None, nodist=None): if dist is not None and dist!=-1: raise NotImplementedError diff --git a/test/test_operators/test_smoothing_operator.py b/test/test_operators/test_smoothing_operator.py index 236e87dae70d3005c5f98f43227a6d00f06ecfa8..79ad0c311aa096d0f5994c60e0c1231c667a5d5c 100644 --- a/test/test_operators/test_smoothing_operator.py +++ b/test/test_operators/test_smoothing_operator.py @@ -58,7 +58,7 @@ class SmoothingOperator_Tests(unittest.TestCase): op = ift.FFTSmoothingOperator(space, sigma=sigma) fld = np.zeros(space.shape, dtype=np.float64) fld[0] = 1. - rand1 = ift.Field(space, ift.dobj.from_global_data(fld, distaxis=-1)) + rand1 = ift.Field(space, ift.dobj.from_global_data(fld)) tt1 = op.times(rand1) assert_allclose(1, tt1.sum())