diff --git a/nifty/data_objects/my_own_do.py b/nifty/data_objects/my_own_do.py index 3ae33ebe2414346799d9c606c011f61b4c8b3de0..8e9d86bf62e5e12398faafdb2d8f376f1d133bc8 100644 --- a/nifty/data_objects/my_own_do.py +++ b/nifty/data_objects/my_own_do.py @@ -236,3 +236,35 @@ def np_allreduce_sum(arr): def dist_axis(arr): return -1 + + +def from_local_data (shape, arr, dist_axis): + if dist_axis!=-1: + raise NotImplementedError + if shape!=arr.shape: + raise ValueError + return data_object(arr) + + +def from_global_data (arr, dist_axis): + if dist_axis!=-1: + raise NotImplementedError + if shape!=arr.shape: + raise ValueError + return data_object(arr) + + +def redistribute (arr, dist=None, nodist=None): + if dist is not None and dist!=-1: + raise NotImplementedError + return arr + + +def default_dist_axis(): + return -1 + + +def local_shape(glob_shape, dist_axis): + if dist_axis!=-1: + raise NotImplementedError + return glob_shape diff --git a/nifty/data_objects/numpy_do.py b/nifty/data_objects/numpy_do.py index cb076fe1bc5da4413130e8ae1b3e8d5aa0e696f4..de2eb0ccdf7a44d17dbfeb6d4b020811f8b4dbb6 100644 --- a/nifty/data_objects/numpy_do.py +++ b/nifty/data_objects/numpy_do.py @@ -43,3 +43,35 @@ def np_allreduce_sum(arr): def dist_axis(arr): return -1 + + +def from_local_data (shape, arr, dist_axis): + if dist_axis!=-1: + raise NotImplementedError + if shape!=arr.shape: + raise ValueError + return arr + + +def from_global_data (arr, dist_axis): + if dist_axis!=-1: + raise NotImplementedError + if shape!=arr.shape: + raise ValueError + return arr + + +def redistribute (arr, dist=None, nodist=None): + if dist is not None and dist!=-1: + raise NotImplementedError + return arr + + +def default_dist_axis(): + return -1 + + +def local_shape(glob_shape, dist_axis): + if dist_axis!=-1: + raise NotImplementedError + return glob_shape diff --git a/nifty/dobj.py b/nifty/dobj.py index ec0d151cc6b122ac17f9cfb008ba6a6354b590a4..cb3af9f93c9eda671cad98b9d2e074f04fcb3da2 100644 --- a/nifty/dobj.py +++ b/nifty/dobj.py @@ -1,2 +1,2 @@ -#from .data_objects.my_own_do import * -from .data_objects.numpy_do import * +from .data_objects.my_own_do import * +#from .data_objects.numpy_do import * diff --git a/nifty/operators/fft_operator_support.py b/nifty/operators/fft_operator_support.py index 998fca9ead14f3b2a8103222f2861f48c40ad827..5ea130396ae891c2115b3f37584c73d77e175da4 100644 --- a/nifty/operators/fft_operator_support.py +++ b/nifty/operators/fft_operator_support.py @@ -114,16 +114,23 @@ class SphericalTransformation(Transformation): def transform(self, x): axes = x.domain.axes[self.space] - if dobj.dist_axis(x.val) in axes: - raise NotImplementedError + axis = axes[0] + tval = x.val + if dobj.dist_axis(tval) == axis: + tval = dobj.redistribute(tval, nodist=(axis,)) + distaxis = dobj.dist_axis(tval) + p2h = x.domain == self.pdom - idat = dobj.local_data(x.val) + idat = dobj.local_data(tval) if p2h: - res = Field(self.hdom, dtype=x.dtype) - odat = dobj.local_data(res.val) + odat = np.empty(dobj.local_shape(self.hdom.shape, dist_axis=distaxis), dtype=x.dtype) for slice in utilities.get_slice_list(idat.shape, axes): odat[slice] = self._slice_p2h(idat[slice]) + odat = dobj.from_local_data(self.hdom.shape, odat, distaxis) + if distaxis!= dobj.dist_axis(x): + odat = dobj.redistribute(odat, dist=distaxis) + return Field(self.hdom, odat) else: res = Field(self.pdom, dtype=x.dtype) odat = dobj.local_data(res.val) diff --git a/nifty/operators/power_projection_operator.py b/nifty/operators/power_projection_operator.py index 022eada5de7a5e8b183a2bd5cd633997d9d27eb9..8a817538beec9f19baabb2024133c286d0ca1b08 100644 --- a/nifty/operators/power_projection_operator.py +++ b/nifty/operators/power_projection_operator.py @@ -52,16 +52,30 @@ class PowerProjectionOperator(LinearOperator): def _times(self, x): pindex = self._target[self._space].pindex - pindex = pindex.reshape((1, pindex.size, 1)) - arr = x.weight(1).val.reshape( - x.domain.collapsed_shape_for_domain(self._space)) - out = dobj.zeros(self._target.collapsed_shape_for_domain(self._space), - dtype=x.dtype) - out = dobj.to_ndarray(out) - np.add.at(out, (slice(None), dobj.to_ndarray(pindex.ravel()), slice(None)), dobj.to_ndarray(arr)) - out = dobj.from_ndarray(out) - return Field(self._target, out.reshape(self._target.shape))\ - .weight(-1, spaces=self._space) + res = Field.zeros(self._target, dtype=x.dtype) + if dobj.dist_axis(x.val) in x.domain.axes[self._space]: # the distributed axis is part of the projected space + pindex = dobj.local_data(pindex) + pindex.reshape((1, pindex.size, 1)) + arr = dobj.local_data(x.weight(1).val) + firstaxis = x.domain.axes[self._space][0] + lastaxis = x.domain.axes[self._space][-1] + presize = np.prod(arr.shape[0:firstaxis], dtype=np.int) + postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int) + arr = arr.reshape((presize,pindex.size,postsize)) + oarr = dobj.local_data(res.val).reshape((presize,-1,postsize)) + np.add.at(oarr, (slice(None), pindex.ravel(), slice(None)), arr) + else: + pindex = dobj.to_ndarray(pindex) + pindex.reshape((1, pindex.size, 1)) + arr = dobj.local_data(x.weight(1).val) + firstaxis = x.domain.axes[self._space][0] + lastaxis = x.domain.axes[self._space][-1] + presize = np.prod(arr.shape[0:firstaxis], dtype=np.int) + postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int) + arr = arr.reshape((presize,pindex.size,postsize)) + oarr = dobj.local_data(res.val).reshape((presize,-1,postsize)) + np.add.at(oarr, (slice(None), pindex.ravel(), slice(None)), arr) + return res.weight(-1, spaces=self._space) def _adjoint_times(self, x): pindex = self._target[self._space].pindex diff --git a/nifty/spaces/power_space.py b/nifty/spaces/power_space.py index d012dc2b5e90ca00ea2b4b45aa7c140172937bac..05fe9ce6a4589a8f6aaacecfe276e5a8020f77dd 100644 --- a/nifty/spaces/power_space.py +++ b/nifty/spaces/power_space.py @@ -143,14 +143,16 @@ class PowerSpace(Space): else: tbb = binbounds locdat = np.searchsorted(tbb, dobj.local_data(k_length_array.val)) - temp_pindex = dobj.create_from_template(k_length_array.val, local_data=locdat, - dtype=locdat.dtype) + temp_pindex = dobj.create_from_template( + k_length_array.val, local_data=locdat, dtype=locdat.dtype) nbin = len(tbb) - temp_rho = np.bincount(dobj.local_data(temp_pindex).ravel(), minlength=nbin) + temp_rho = np.bincount(dobj.local_data(temp_pindex).ravel(), + minlength=nbin) temp_rho = dobj.np_allreduce_sum(temp_rho) assert not (temp_rho == 0).any(), "empty bins detected" temp_k_lengths = np.bincount(dobj.local_data(temp_pindex).ravel(), - weights=dobj.local_data(k_length_array.val).ravel(), minlength=nbin) + weights=dobj.local_data(k_length_array.val).ravel(), + minlength=nbin) temp_k_lengths = dobj.np_allreduce_sum(temp_k_lengths) / temp_rho temp_dvol = temp_rho*pdvol self._powerIndexCache[key] = (binbounds,