distributed_do.py 15.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

19
20
21
import numpy as np
from .random import Random
from mpi4py import MPI
22
import sys
Martin Reinecke's avatar
fix  
Martin Reinecke committed
23
from functools import reduce
24

Martin Reinecke's avatar
tweaks  
Martin Reinecke committed
25
26
27
# Communicator used by every collective operation in this module.
_comm = MPI.COMM_WORLD
# Number of MPI tasks in the communicator and the index of the current task.
ntask, rank = _comm.Get_size(), _comm.Get_rank()
# True only on the task with rank 0.
master = rank == 0
29
30


Martin Reinecke's avatar
Martin Reinecke committed
31
32
33
34
def is_numpy():
    """Return False: this backend is the MPI-based one, not the numpy one."""
    return False


Martin Reinecke's avatar
tweaks  
Martin Reinecke committed
35
def _shareSize(nwork, nshares, myshare):
Martin Reinecke's avatar
Martin Reinecke committed
36
    return (nwork//nshares) + int(myshare < nwork % nshares)
Martin Reinecke's avatar
Martin Reinecke committed
37

Martin Reinecke's avatar
tweaks  
Martin Reinecke committed
38
39

def _shareRange(nwork, nshares, myshare):
Martin Reinecke's avatar
Martin Reinecke committed
40
41
    nbase = nwork//nshares
    additional = nwork % nshares
Martin Reinecke's avatar
Martin Reinecke committed
42
    lo = myshare*nbase + min(myshare, additional)
Martin Reinecke's avatar
Martin Reinecke committed
43
    hi = lo + nbase + int(myshare < additional)
Martin Reinecke's avatar
Martin Reinecke committed
44
45
    return lo, hi

46

47
def local_shape(shape, distaxis=0):
    """Return the shape of the part of a global array held by this task.

    `shape` is returned unchanged when it is empty or when `distaxis` is -1.
    Otherwise the extent along `distaxis` is replaced by this task's share
    of it and the result is returned as a tuple.
    """
    undistributed = len(shape) == 0 or distaxis == -1
    if undistributed:
        return shape
    dims = list(shape)
    dims[distaxis] = _shareSize(shape[distaxis], ntask, rank)
    return tuple(dims)

Martin Reinecke's avatar
Martin Reinecke committed
54

55
56
class data_object(object):
    """Array whose contents are spread over the MPI tasks along one axis.

    Every instance stores the global shape (`_shape`), the part of the data
    held by the current task (`_data`) and the index of the axis along which
    the data is split (`_distaxis`).  A `_distaxis` of -1 is treated
    throughout this class as "not distributed": reductions then skip the MPI
    communication step.
    """

    def __init__(self, shape, data, distaxis):
        """Store global `shape`, local `data` and distribution axis.

        Zero-dimensional objects are always marked as not distributed
        (`distaxis` is forced to -1).
        """
        self._shape = tuple(shape)
        if len(self._shape) == 0:
            distaxis = -1
        self._distaxis = distaxis
        self._data = data

#     def _sanity_checks(self):
#         # check whether the distaxis is consistent
#         if self._distaxis < -1 or self._distaxis >= len(self._shape):
#             raise ValueError
#         itmp = np.array(self._distaxis)
#         otmp = np.empty(ntask, dtype=np.int)
#         _comm.Allgather(itmp, otmp)
#         if np.any(otmp != self._distaxis):
#             raise ValueError
#         # check whether the global shape is consistent
#         itmp = np.array(self._shape)
#         otmp = np.empty((ntask, len(self._shape)), dtype=np.int)
#         _comm.Allgather(itmp, otmp)
#         for i in range(ntask):
#             if np.any(otmp[i, :] != self._shape):
#                 raise ValueError
#         # check shape of local data
#         if self._distaxis < 0:
#             if self._data.shape != self._shape:
#                 raise ValueError
#         else:
#             itmp = np.array(self._shape)
#             itmp[self._distaxis] = _shareSize(self._shape[self._distaxis],
#                                               ntask, rank)
#             if np.any(self._data.shape != itmp):
#                 raise ValueError

    @property
    def dtype(self):
        """Data type of the local data."""
        return self._data.dtype

    @property
    def shape(self):
        """Global shape as a tuple."""
        return self._shape

    @property
    def size(self):
        """Total number of elements of the global array, as a Python int."""
        # np.prod(()) would yield the float 1.0 for zero-dimensional objects,
        # hence the explicit integer product.
        return int(reduce(lambda x, y: x*y, self._shape, 1))

    @property
    def real(self):
        """Real part, with the same shape and distribution."""
        return data_object(self._shape, self._data.real, self._distaxis)

    @property
    def imag(self):
        """Imaginary part, with the same shape and distribution."""
        return data_object(self._shape, self._data.imag, self._distaxis)

    def conj(self):
        """Return the complex conjugate as a new object."""
        return data_object(self._shape, self._data.conj(), self._distaxis)

    def conjugate(self):
        """Return the complex conjugate as a new object."""
        return data_object(self._shape, self._data.conjugate(), self._distaxis)

    def _normalize_axis(self, axis):
        """Convert `axis` to None or a tuple of non-negative axis indices.

        A single integer becomes a one-element tuple and valid negative
        indices are counted from the end.  Out-of-range entries are passed
        through unchanged so that the subsequent numpy call reports them.
        """
        if axis is None:
            return None
        if isinstance(axis, (int, np.integer)):
            axis = (axis,)
        ndim = len(self._shape)
        return tuple(ax + ndim if -ndim <= ax < 0 else ax for ax in axis)

    def _contraction_helper(self, op, mpiop, axis):
        """Reduce the data with the ndarray method named `op`.

        Parameters
        ----------
        op : str
            Name of the numpy reduction method ("sum", "prod", "min", ...).
        mpiop :
            MPI operation used to combine the per-task partial results.
        axis : None, int or sequence of int
            Axes to reduce over; None (or all axes) reduces to a scalar.

        Returns
        -------
        A scalar for full reductions, otherwise a new data_object.
        """
        axis = self._normalize_axis(axis)
        if axis is not None:
            if len(axis) == len(self._data.shape):
                axis = None
        if axis is None:
            res = np.array(getattr(self._data, op)())
            if (self._distaxis == -1):
                return res[()]
            res2 = np.empty((), dtype=res.dtype)
            _comm.Allreduce(res, res2, mpiop)
            return res2[()]

        if self._distaxis in axis:
            # the distributed axis is contracted: combine the partial
            # results of all tasks, then distribute the result along axis 0
            res = getattr(self._data, op)(axis=axis)
            res2 = np.empty_like(res)
            _comm.Allreduce(res, res2, mpiop)
            return from_global_data(res2, distaxis=0)
        else:
            # perform the contraction on the local data
            res = getattr(self._data, op)(axis=axis)
            if self._distaxis == -1:
                return from_global_data(res, distaxis=0)
            shp = list(res.shape)
            # every contracted axis in front of the distributed axis moves
            # the distributed axis one position to the left
            shift = 0
            for ax in axis:
                if ax < self._distaxis:
                    shift += 1
            shp[self._distaxis-shift] = self.shape[self._distaxis]
            return from_local_data(shp, res, self._distaxis-shift)

    def sum(self, axis=None):
        """Sum over the given axes (all axes if `axis` is None)."""
        return self._contraction_helper("sum", MPI.SUM, axis)

    def prod(self, axis=None):
        """Product over the given axes (all axes if `axis` is None)."""
        return self._contraction_helper("prod", MPI.PROD, axis)

    def min(self, axis=None):
        """Minimum over the given axes (all axes if `axis` is None)."""
        return self._contraction_helper("min", MPI.MIN, axis)

    def max(self, axis=None):
        """Maximum over the given axes (all axes if `axis` is None)."""
        return self._contraction_helper("max", MPI.MAX, axis)

    def mean(self, axis=None):
        """Arithmetic mean over the given axes (all axes if `axis` is None)."""
        axis = self._normalize_axis(axis)
        if axis is None:
            sz = self.size
        else:
            sz = reduce(lambda x, y: x*y, [self.shape[i] for i in axis], 1)
        return self.sum(axis)/sz

    def std(self, axis=None):
        """Standard deviation; see `var` for the supported `axis` values."""
        return np.sqrt(self.var(axis))

    # FIXME: to be improved!
    def var(self, axis=None):
        """Variance over all elements.

        Raises
        ------
        ValueError
            If `axis` is given and does not cover every axis of the object.
        """
        axis = self._normalize_axis(axis)
        if axis is not None and len(axis) != len(self.shape):
            raise ValueError("functionality not yet supported")
        return (abs(self-self.mean())**2).mean()

    def _binary_helper(self, other, op):
        """Apply the ndarray special method `op` to the local data.

        `other` may be a data_object with identical shape and distribution,
        a scalar or a numpy array; for any other type NotImplemented is
        returned.  In-place operations return `self`.

        Raises
        ------
        ValueError
            If `other` is a data_object with a different shape or
            distribution axis.
        """
        a = self._data
        if isinstance(other, data_object):
            if self._shape != other._shape:
                raise ValueError("shapes are incompatible.")
            if self._distaxis != other._distaxis:
                raise ValueError("distributions are incompatible.")
            b = other._data
        elif np.isscalar(other) or isinstance(other, np.ndarray):
            b = other
        else:
            return NotImplemented

        tval = getattr(a, op)(b)
        # do not wrap a refusal of the underlying array into a data_object
        if tval is NotImplemented:
            return NotImplemented
        if tval is a:
            return self
        return data_object(self._shape, tval, self._distaxis)

    def __neg__(self):
        return data_object(self._shape, -self._data, self._distaxis)

    def __abs__(self):
        return data_object(self._shape, abs(self._data), self._distaxis)

    def _count_nonzero(self):
        """Return the number of nonzero elements of the global array."""
        local = np.array(np.count_nonzero(self._data), dtype=np.int64)
        if self._distaxis == -1:
            return local[()]
        total = np.empty((), dtype=local.dtype)
        _comm.Allreduce(local, total, MPI.SUM)
        return total[()]

    def all(self):
        """Return True if every element of the global array is nonzero."""
        # counting nonzero entries (instead of summing the values) keeps the
        # result correct for non-boolean data
        return self._count_nonzero() == self.size

    def any(self):
        """Return True if at least one element of the global array is nonzero."""
        # a plain sum could cancel to zero, e.g. for the values [1, -1]
        return self._count_nonzero() != 0

    def fill(self, value):
        """Set every element of the local data to `value` (in place)."""
        self._data.fill(value)
213

214

215
216
217
218
219
220
221
222
223
224
225
226
227
228
def func(op):
    """Build a method that forwards the binary operator `op` to
    `data_object._binary_helper`."""
    def func2(self, other):
        return self._binary_helper(other, op=op)
    return func2


# Attach all arithmetic and comparison operators to data_object.
for op in ("__add__", "__radd__", "__iadd__",
           "__sub__", "__rsub__", "__isub__",
           "__mul__", "__rmul__", "__imul__",
           "__div__", "__rdiv__", "__idiv__",
           "__truediv__", "__rtruediv__", "__itruediv__",
           "__floordiv__", "__rfloordiv__", "__ifloordiv__",
           "__pow__", "__rpow__", "__ipow__",
           "__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"):
    setattr(data_object, op, func(op))

Martin Reinecke's avatar
Martin Reinecke committed
229

Martin Reinecke's avatar
Martin Reinecke committed
230
def full(shape, fill_value, dtype=None, distaxis=0):
    """Return a data_object of global shape `shape`, distributed along
    `distaxis`, whose local data is filled with `fill_value`."""
    local = np.full(local_shape(shape, distaxis), fill_value, dtype)
    return data_object(shape, local, distaxis)
233
234


Martin Reinecke's avatar
fixes  
Martin Reinecke committed
235
def empty(shape, dtype=None, distaxis=0):
    """Return a data_object of global shape `shape`, distributed along
    `distaxis`, whose local data is left uninitialized."""
    local = np.empty(local_shape(shape, distaxis), dtype)
    return data_object(shape, local, distaxis)
238
239


Martin Reinecke's avatar
fixes  
Martin Reinecke committed
240
def zeros(shape, dtype=None, distaxis=0):
    """Return a data_object of global shape `shape`, distributed along
    `distaxis`, whose local data is filled with zeros."""
    local = np.zeros(local_shape(shape, distaxis), dtype)
    return data_object(shape, local, distaxis)
243
244


Martin Reinecke's avatar
fixes  
Martin Reinecke committed
245
def ones(shape, dtype=None, distaxis=0):
    """Return a data_object of global shape `shape`, distributed along
    `distaxis`, whose local data is filled with ones."""
    local = np.ones(local_shape(shape, distaxis), dtype)
    return data_object(shape, local, distaxis)
248
249
250
251
252
253
254


def empty_like(a, dtype=None):
    """Return an uninitialized data_object with the shape and distribution
    of `a`.

    `dtype` overrides the data type of the result; None keeps that of `a`.
    """
    # data_object needs the global shape and the distribution axis in
    # addition to the local data.
    return data_object(a._shape, np.empty_like(a._data, dtype), a._distaxis)


def vdot(a, b):
    """Return the dot product of the flattened arrays `a` and `b`.

    The local contribution is computed with `np.vdot` (which conjugates its
    first argument).  For distributed operands the contributions of all
    tasks are summed; for non-distributed operands every task already holds
    the complete result, so no communication takes place.
    """
    tmp = np.array(np.vdot(a._data, b._data))
    if a._distaxis == -1:
        # summing over tasks would multiply the result by the task count
        return tmp[()]
    res = np.empty((), dtype=tmp.dtype)
    _comm.Allreduce(tmp, res, MPI.SUM)
    return res[()]
259
260
261


def _math_helper(x, function, out):
262
    function = getattr(np, function)
263
264
265
266
    if out is not None:
        function(x._data, out=out._data)
        return out
    else:
Martin Reinecke's avatar
Martin Reinecke committed
267
        return data_object(x.shape, function(x._data), x._distaxis)
268
269


270
_current_module = sys.modules[__name__]


def func(f):
    """Build a module-level function applying the numpy function named `f`
    via `_math_helper`."""
    def func2(x, out=None):
        return _math_helper(x, f, out)
    return func2


# Export the element-wise math functions as attributes of this module.
for f in ("sqrt", "exp", "log", "tanh", "conjugate"):
    setattr(_current_module, f, func(f))
278
279


Martin Reinecke's avatar
Martin Reinecke committed
280
281
282
283
284
285
286
287
288
289
290
291
def from_object(object, dtype, copy, set_locked):
    """Create a data_object from an existing one.

    Parameters
    ----------
    object : data_object
        The source object.
    dtype : data type or None
        Requested data type; None keeps the data type of `object`.
    copy : bool
        Passed on to `np.array` when building the new local data.
    set_locked : bool
        If True, the local data of the result is made read-only.

    Returns
    -------
    `object` itself if it is already locked, locking was requested and no
    type change is necessary; otherwise a new data_object.

    Raises
    ------
    ValueError
        If a type change or locking is requested while `copy` is False.
    """
    target = object.dtype if dtype is None else dtype
    same_dtype = target == object.dtype
    if set_locked and same_dtype and locked(object):
        return object
    if not (same_dtype or copy):
        raise ValueError("cannot change data type without copying")
    if set_locked and not copy:
        raise ValueError("cannot lock object without copying")
    data = np.array(object._data, dtype=target, copy=copy)
    if set_locked:
        data.flags.writeable = False
    return data_object(object._shape, data, distaxis=object._distaxis)
294
295


Martin Reinecke's avatar
Martin Reinecke committed
296
297
# This function draws all random numbers on all tasks, to produce the same
# array independent of the number of tasks
Martin Reinecke's avatar
Martin Reinecke committed
298
299
300
# MR FIXME: depending on what is really wanted/needed (i.e. same result
# independent of number of tasks, performance etc.) we need to adjust the
# algorithm.
Martin Reinecke's avatar
Martin Reinecke committed
301
def from_random(random_type, shape, dtype=np.float64, **kwargs):
    """Return a data_object filled with random numbers, distributed along
    axis 0.

    Parameters
    ----------
    random_type : str
        Name of the generator method looked up on `Random`.
    shape : tuple of int
        Global shape of the result.
    dtype : data type
        Passed on to the generator.
    **kwargs :
        Additional arguments passed on to the generator.
    """
    generator_function = getattr(Random, random_type)
    if len(shape) == 0:
        # a zero-dimensional result has no axis to distribute; draw it once
        # and make sure all tasks hold the value drawn on task 0
        ldat = generator_function(dtype=dtype, shape=shape, **kwargs)
        ldat = np.asarray(_comm.bcast(ldat))
        return from_local_data(shape, ldat, distaxis=-1)
    outdat = None
    # Every task draws the chunks of all tasks in turn and keeps only its own.
    for i in range(ntask):
        lshape = list(shape)
        lshape[0] = _shareSize(shape[0], ntask, i)
        ldat = generator_function(dtype=dtype, shape=lshape, **kwargs)
        if i == rank:
            outdat = ldat
    return from_local_data(shape, outdat, distaxis=0)
310

Martin Reinecke's avatar
Martin Reinecke committed
311

Martin Reinecke's avatar
Martin Reinecke committed
312
313
314
315
def local_data(arr):
    """Return the array holding the part of `arr` stored on this task."""
    local = arr._data
    return local


316
317
def ibegin_from_shape(glob_shape, distaxis=0):
    """Return the global index at which this task's local data begins.

    Parameters
    ----------
    glob_shape : tuple of int
        Global shape of the array.
    distaxis : int
        Axis along which the array is distributed; negative values mean
        "not distributed".

    Returns
    -------
    tuple of int
        One start index per axis; zero everywhere except along `distaxis`.
    """
    res = [0] * len(glob_shape)
    if distaxis >= 0:
        res[distaxis] = _shareRange(glob_shape[distaxis], ntask, rank)[0]
    # always return a tuple, also for non-distributed arrays
    return tuple(res)


Martin Reinecke's avatar
fixes  
Martin Reinecke committed
324
325
def ibegin(arr):
    """Return the global index at which the local data of `arr` begins.

    The result is a tuple with one start index per axis.  All entries are
    zero except the one for the distribution axis; for objects that are not
    distributed (`_distaxis` == -1) every entry is zero.
    """
    res = [0] * arr._data.ndim
    if arr._distaxis >= 0:
        res[arr._distaxis] = _shareRange(arr._shape[arr._distaxis],
                                         ntask, rank)[0]
    return tuple(res)
Martin Reinecke's avatar
Martin Reinecke committed
328
329


Martin Reinecke's avatar
fixes  
Martin Reinecke committed
330
331
def np_allreduce_sum(arr):
    """Return the element-wise sum of the numpy array `arr` over all tasks."""
    total = np.empty_like(arr)
    _comm.Allreduce(arr, total, MPI.SUM)
    return total
Martin Reinecke's avatar
Martin Reinecke committed
334
335


336
337
338
339
340
341
def np_allreduce_min(arr):
    """Return the element-wise minimum of the numpy array `arr` over all
    tasks."""
    smallest = np.empty_like(arr)
    _comm.Allreduce(arr, smallest, MPI.MIN)
    return smallest


Martin Reinecke's avatar
Martin Reinecke committed
342
343
344
345
def distaxis(arr):
    """Return the axis along which `arr` is distributed."""
    axis = arr._distaxis
    return axis


Martin Reinecke's avatar
Martin Reinecke committed
346
def from_local_data(shape, arr, distaxis=0):
    """Wrap the local array `arr` into a data_object with global shape
    `shape`, distributed along `distaxis`."""
    wrapped = data_object(shape, arr, distaxis)
    return wrapped


350
351
352
def from_global_data(arr, sum_up=False, distaxis=0):
    """Create a data_object from an array that holds the full global data.

    Parameters
    ----------
    arr : numpy.ndarray
        Array with the global shape.
    sum_up : bool
        If True, `arr` is first summed element-wise over all tasks.
    distaxis : int
        Axis along which the result is distributed; -1 keeps the complete
        array on every task.

    Returns
    -------
    data_object
        Holds the slice of `arr` belonging to this task along `distaxis`.
    """
    if sum_up:
        arr = np_allreduce_sum(arr)
    # zero-dimensional data has no axis to distribute along
    if distaxis == -1 or len(arr.shape) == 0:
        return data_object(arr.shape, arr, -1)
    lo, hi = _shareRange(arr.shape[distaxis], ntask, rank)
    sl = [slice(None)]*len(arr.shape)
    sl[distaxis] = slice(lo, hi)
    # multidimensional indices must be tuples; lists of slices are rejected
    # by current NumPy versions
    return data_object(arr.shape, arr[tuple(sl)], distaxis)


Martin Reinecke's avatar
Martin Reinecke committed
361
362
def to_global_data(arr):
    """Return a numpy array containing the complete global data of `arr`.

    Objects that are not distributed hand out their data directly; all
    others are redistributed to "not distributed" first.
    """
    if arr._distaxis != -1:
        arr = redistribute(arr, dist=-1)
    return arr._data


Martin Reinecke's avatar
Martin Reinecke committed
368
def redistribute(arr, dist=None, nodist=None):
    """Return `arr` with a different distribution axis.

    Exactly one of `dist` and `nodist` must be given.

    Parameters
    ----------
    arr : data_object
        The object to redistribute.
    dist : int, optional
        Requested distribution axis; -1 gathers the full array on every
        task.
    nodist : sequence of int, optional
        Axes along which the result must not be distributed.  The first
        axis not contained in `nodist` becomes the new distribution axis;
        if there is none, the result is not distributed.

    Returns
    -------
    data_object
        `arr` itself if it already satisfies the request, otherwise a new
        object.

    Raises
    ------
    ValueError
        If both or neither of `dist` and `nodist` are given.
    """
    if dist is not None:
        if nodist is not None:
            raise ValueError
        if dist == arr._distaxis:
            return arr
    else:
        if nodist is None:
            raise ValueError
        if arr._distaxis not in nodist:
            return arr
        dist = -1
        for i in range(len(arr.shape)):
            if i not in nodist:
                dist = i
                break

    if arr._distaxis == -1:  # all data available, just pick the proper subset
        return from_global_data(arr._data, distaxis=dist)
    if dist == -1:  # gather all data on all tasks
        # bring the distributed axis to the front so that the chunks of the
        # individual tasks are contiguous in the gathered buffer
        tmp = np.moveaxis(arr._data, arr._distaxis, 0)
        slabsize = int(np.prod(tmp.shape[1:]))*tmp.itemsize
        # np.int was removed from NumPy; the builtin int is equivalent
        sz = np.empty(ntask, dtype=int)
        for i in range(ntask):
            sz[i] = slabsize*_shareSize(arr.shape[arr._distaxis], ntask, i)
        disp = np.empty(ntask, dtype=int)
        disp[0] = 0
        disp[1:] = np.cumsum(sz[:-1])
        tmp = np.require(tmp, requirements="C")
        out = np.empty(arr.size, dtype=arr.dtype)
        _comm.Allgatherv(tmp, [out, sz, disp, MPI.BYTE])
        # global shape with the formerly distributed axis moved to the front
        shp = [arr.shape[arr._distaxis]]
        shp += [n for i, n in enumerate(arr._shape) if i != arr._distaxis]
        out = out.reshape(shp)
        out = np.moveaxis(out, 0, arr._distaxis)
        return from_global_data(out, distaxis=-1)

    # real redistribution via Alltoallv
    ssz0 = arr._data.size//arr.shape[dist]
    ssz = np.empty(ntask, dtype=int)
    rszall = arr.size//arr.shape[dist]*_shareSize(arr.shape[dist], ntask, rank)
    rbuf = np.empty(rszall, dtype=arr.dtype)
    rsz0 = rszall//arr.shape[arr._distaxis]
    rsz = np.empty(ntask, dtype=int)
    if dist == 0:  # shortcut possible
        sbuf = np.ascontiguousarray(arr._data)
        for i in range(ntask):
            lo, hi = _shareRange(arr.shape[dist], ntask, i)
            ssz[i] = ssz0*(hi-lo)
            rsz[i] = rsz0*_shareSize(arr.shape[arr._distaxis], ntask, i)
    else:
        # pack the pieces destined for the individual tasks one after the
        # other into the send buffer
        sbuf = np.empty(arr._data.size, dtype=arr.dtype)
        sslice = [slice(None)]*arr._data.ndim
        ofs = 0
        for i in range(ntask):
            lo, hi = _shareRange(arr.shape[dist], ntask, i)
            sslice[dist] = slice(lo, hi)
            ssz[i] = ssz0*(hi-lo)
            sbuf[ofs:ofs+ssz[i]] = arr._data[tuple(sslice)].flat
            ofs += ssz[i]
            rsz[i] = rsz0*_shareSize(arr.shape[arr._distaxis], ntask, i)
    # counts and displacements are given in bytes (MPI.BYTE is used below)
    ssz *= arr._data.itemsize
    rsz *= arr._data.itemsize
    sdisp = np.append(0, np.cumsum(ssz[:-1]))
    rdisp = np.append(0, np.cumsum(rsz[:-1]))
    s_msg = [sbuf, (ssz, sdisp), MPI.BYTE]
    r_msg = [rbuf, (rsz, rdisp), MPI.BYTE]
    _comm.Alltoallv(s_msg, r_msg)
    del sbuf  # free memory
    if arr._distaxis == 0:
        rbuf = rbuf.reshape(local_shape(arr.shape, dist))
        arrnew = from_local_data(arr.shape, rbuf, distaxis=dist)
    else:
        # unpack the pieces received from the individual tasks
        arrnew = empty(arr.shape, dtype=arr.dtype, distaxis=dist)
        rslice = [slice(None)]*arr._data.ndim
        ofs = 0
        for i in range(ntask):
            lo, hi = _shareRange(arr.shape[arr._distaxis], ntask, i)
            rslice[arr._distaxis] = slice(lo, hi)
            sz = rsz[i]//arr._data.itemsize
            arrnew._data[tuple(rslice)].flat = rbuf[ofs:ofs+sz]
            ofs += sz
    return arrnew
Martin Reinecke's avatar
Martin Reinecke committed
452
453


Martin Reinecke's avatar
Martin Reinecke committed
454
455
def transpose(arr):
    """Return the transpose of a two-dimensional data_object.

    The input must be distributed along axis 0; the result has the swapped
    global shape and is again distributed along axis 0.

    Raises
    ------
    ValueError
        If `arr` is not two-dimensional or not distributed along axis 0.
    """
    if len(arr.shape) != 2 or arr._distaxis != 0:
        raise ValueError("bad input")
    ssz0 = arr._data.size//arr.shape[1]
    # np.int was removed from NumPy; the builtin int is equivalent
    ssz = np.empty(ntask, dtype=int)
    rszall = arr.size//arr.shape[1]*_shareSize(arr.shape[1], ntask, rank)
    rbuf = np.empty(rszall, dtype=arr.dtype)
    rsz0 = rszall//arr.shape[0]
    rsz = np.empty(ntask, dtype=int)
    sbuf = np.empty(arr._data.size, dtype=arr.dtype)
    ofs = 0
    # pack the column blocks destined for the individual tasks one after
    # the other into the send buffer
    for i in range(ntask):
        lo, hi = _shareRange(arr.shape[1], ntask, i)
        ssz[i] = ssz0*(hi-lo)
        sbuf[ofs:ofs+ssz[i]] = arr._data[:, lo:hi].flat
        ofs += ssz[i]
        rsz[i] = rsz0*_shareSize(arr.shape[0], ntask, i)
    # counts and displacements are given in bytes (MPI.BYTE is used below)
    ssz *= arr._data.itemsize
    rsz *= arr._data.itemsize
    sdisp = np.append(0, np.cumsum(ssz[:-1]))
    rdisp = np.append(0, np.cumsum(rsz[:-1]))
    s_msg = [sbuf, (ssz, sdisp), MPI.BYTE]
    r_msg = [rbuf, (rsz, rdisp), MPI.BYTE]
    _comm.Alltoallv(s_msg, r_msg)
    del sbuf  # free memory
    arrnew = empty((arr.shape[1], arr.shape[0]), dtype=arr.dtype, distaxis=0)
    ofs = 0
    sz2 = _shareSize(arr.shape[1], ntask, rank)
    # every received block is a row block of the input; store it transposed
    for i in range(ntask):
        lo, hi = _shareRange(arr.shape[0], ntask, i)
        sz = rsz[i]//arr._data.itemsize
        arrnew._data[:, lo:hi] = rbuf[ofs:ofs+sz].reshape(hi-lo, sz2).T
        ofs += sz
    return arrnew


Martin Reinecke's avatar
Martin Reinecke committed
490
491
def default_distaxis():
    """Return the distribution axis used by default, i.e. the first axis."""
    first_axis = 0
    return first_axis
492
493
494
495
496
497
498
499


def lock(arr):
    """Make the local data of `arr` read-only."""
    flags = arr._data.flags
    flags.writeable = False


def locked(arr):
    """Return True if the local data of `arr` is read-only."""
    writeable = arr._data.flags.writeable
    return not writeable