diff --git a/nifty5/data_objects/distributed_do.py b/nifty5/data_objects/distributed_do.py index 393e22a94cf68318cf7749a9558c8f514856462e..4c53c94363c596227d67fb596b6d3589985723d8 100644 --- a/nifty5/data_objects/distributed_do.py +++ b/nifty5/data_objects/distributed_do.py @@ -59,6 +59,8 @@ class data_object(object): distaxis = -1 self._distaxis = distaxis self._data = data + if local_shape(self._shape, self._distaxis) != self._data.shape: + raise ValueError("shape mismatch") # def _sanity_checks(self): # # check whether the distaxis is consistent diff --git a/nifty5/operators/exp_transform.py b/nifty5/operators/exp_transform.py index d9c846519b9e7e821bfb15566cba1760a5d75786..4f56a07d1815ae24f8066661a2ac1147631e9fb3 100644 --- a/nifty5/operators/exp_transform.py +++ b/nifty5/operators/exp_transform.py @@ -4,6 +4,7 @@ from ..domain_tuple import DomainTuple from ..domains import PowerSpace, RGSpace from ..field import Field from .linear_operator import LinearOperator +from .. import dobj class ExpTransform(LinearOperator): @@ -27,7 +28,7 @@ class ExpTransform(LinearOperator): for i in range(ndim): if isinstance(target, RGSpace): rng = np.arange(target.shape[i]) - tmp = np.minimum(rng, target.shape[i] + 1 - rng) + tmp = np.minimum(rng, target.shape[i]+1-rng) k_array = tmp * target.distances[i] else: k_array = target.k_lengths @@ -42,8 +43,8 @@ class ExpTransform(LinearOperator): # Save t_min for later t_mins[i] = t_min - bindistances[i] = (t_max - t_min) / (dof[i] - 1) - coord = np.append(0., 1. + (log_k_array - t_min) / bindistances[i]) + bindistances[i] = (t_max-t_min) / (dof[i]-1) + coord = np.append(0., 1. + (log_k_array-t_min) / bindistances[i]) self._bindex[i] = np.floor(coord).astype(int) # Interpolated value is computed via @@ -52,7 +53,7 @@ class ExpTransform(LinearOperator): self._frac[i] = coord - self._bindex[i] from ..domains import LogRGSpace - log_space = LogRGSpace(2 * dof + 1, bindistances, + log_space = LogRGSpace(2*dof+1, bindistances, t_mins, harmonic=False) self._target = DomainTuple.make(target) self._domain = DomainTuple.make(log_space) @@ -67,30 +68,34 @@ class ExpTransform(LinearOperator): def apply(self, x, mode): self._check_input(x, mode) - x = x.to_global_data() + x = x.val + ax = dobj.distaxis(x) ndim = len(self.target.shape) - idx = () + curshp = list(self._dom(mode).shape) for d in range(ndim): - fst_dims = (1,) * d - lst_dims = (1,) * (ndim - d - 1) - wgt = self._frac[d].reshape(fst_dims + (-1,) + lst_dims) + idx = (slice(None,),) * d + wgt = self._frac[d].reshape((1,)*d + (-1,) + (1,)*(ndim-d-1)) + + if d == ax: + x = dobj.redistribute(x, nodist=(ax,)) + curax = dobj.distaxis(x) + x = dobj.local_data(x) - # ADJOINT_TIMES if mode == self.ADJOINT_TIMES: shp = list(x.shape) shp[d] = self._tgt(mode).shape[d] xnew = np.zeros(shp, dtype=x.dtype) - np.add.at(xnew, idx + (self._bindex[d],), x * (1. - wgt)) - np.add.at(xnew, idx + (self._bindex[d] + 1,), x * wgt) - - # TIMES - else: - xnew = x[idx + (self._bindex[d],)] * (1. - wgt) - xnew += x[idx + (self._bindex[d] + 1,)] * wgt - - x = xnew - idx = (slice(None),) + idx - return Field.from_global_data(self._tgt(mode), x) + np.add.at(xnew, idx + (self._bindex[d],), x * (1.-wgt)) + np.add.at(xnew, idx + (self._bindex[d]+1,), x * wgt) + else: # TIMES + xnew = x[idx + (self._bindex[d],)] * (1.-wgt) + xnew += x[idx + (self._bindex[d]+1,)] * wgt + + curshp[d] = self._tgt(mode).shape[d] + x = dobj.from_local_data(curshp, xnew, distaxis=curax) + if d == ax: + x = dobj.redistribute(x, dist=ax) + return Field(self._tgt(mode), val=x) @property def capability(self): diff --git a/nifty5/operators/qht_operator.py b/nifty5/operators/qht_operator.py index 6149d83de51db5b2fbb163f87225808cd73257d6..cf9c3d8a7c1d9fb4d76d39c849e7cae9c186a769 100644 --- a/nifty5/operators/qht_operator.py +++ b/nifty5/operators/qht_operator.py @@ -35,18 +35,18 @@ class QHTOperator(LinearOperator): x = x.val * self.domain[0].scalar_dvol() n = len(self.domain[0].shape) rng = range(n) if mode == self.TIMES else reversed(range(n)) + ax = dobj.distaxis(x) + globshape = x.shape for i in rng: sl = (slice(None),)*i + (slice(1, None),) - if i == dobj.distaxis(x): - x = dobj.redistribute(x, nodist=(i,)) - ax = dobj.distaxis(x) - x = dobj.local_data(x) - x[sl] = hartley(x[sl], axes=(i,)) - x = dobj.from_local_data(x.shape, x, distaxis=ax) - x = dobj.redistribute(x, dist=i) - else: - x[sl] = hartley(x[sl], axes=(i,)) - + if i == ax: + x = dobj.redistribute(x, nodist=(ax,)) + curax = dobj.distaxis(x) + x = dobj.local_data(x) + x[sl] = hartley(x[sl], axes=(i,)) + x = dobj.from_local_data(globshape, x, distaxis=curax) + if i == ax: + x = dobj.redistribute(x, dist=ax) return Field(self._tgt(mode), val=x) @property diff --git a/nifty5/operators/symmetrizing_operator.py b/nifty5/operators/symmetrizing_operator.py index 0edc3c2ccdacb0fe32f250aa3d28b93a38afc678..a0e429f802671a2f7eeae89901f4eb7f3e7d9184 100644 --- a/nifty5/operators/symmetrizing_operator.py +++ b/nifty5/operators/symmetrizing_operator.py @@ -1,6 +1,7 @@ from ..domain_tuple import DomainTuple from ..field import Field from .endomorphic_operator import EndomorphicOperator +from .. import dobj class SymmetrizingOperator(EndomorphicOperator): @@ -14,12 +15,20 @@ class SymmetrizingOperator(EndomorphicOperator): def apply(self, x, mode): self._check_input(x, mode) - # FIXME Not efficient with MPI - tmp = x.to_global_data().copy() + tmp = x.copy().val + ax = dobj.distaxis(tmp) + globshape = tmp.shape for i in range(self._ndim): lead = (slice(None),)*i + if i == ax: + tmp = dobj.redistribute(tmp, nodist=(ax,)) + curax = dobj.distaxis(tmp) + tmp = dobj.local_data(tmp) tmp[lead + (slice(1, None),)] -= tmp[lead + (slice(None, 0, -1),)] - return Field.from_global_data(self.target, tmp) + tmp = dobj.from_local_data(globshape, tmp, distaxis=curax) + if i == ax: + tmp = dobj.redistribute(tmp, dist=ax) + return Field(self.target, val=tmp) @property def capability(self):