/*
 *  This file is part of the MR utility library.
 *
 *  This code is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This code is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this code; if not, write to the Free Software
 *  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 */

/* Copyright (C) 2019-2020 Max-Planck-Society
   Author: Martin Reinecke */
#ifndef MRUTIL_MAV_H
#define MRUTIL_MAV_H

#include <cstddef>
#include <cstdlib>
#include <array>
#include <functional>
#include <memory>
#include <numeric>
#include <tuple>
#include <vector>
#include "mr_util/infra/error_handling.h"

Martin Reinecke's avatar
Martin Reinecke committed
32
33
namespace mr {

Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
34
namespace detail_mav {
Martin Reinecke's avatar
Martin Reinecke committed
35
36
37

using namespace std;

38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
template<typename T> class membuf
  {
  protected:
    using Tsp = shared_ptr<vector<T>>;
    Tsp ptr;
    const T *d;
    bool rw;

    membuf(const T *d_, membuf &other)
      : ptr(other.ptr), d(d_), rw(other.rw) {}
    membuf(const T *d_, const membuf &other)
      : ptr(other.ptr), d(d_), rw(false) {}

    // externally owned data pointer
    membuf(T *d_, bool rw_=false)
      : d(d_), rw(rw_) {}
    // externally owned data pointer, nonmodifiable
    membuf(const T *d_)
      : d(d_), rw(false) {}
    // allocate own memory
    membuf(size_t sz)
      : ptr(make_unique<vector<T>>(sz)), d(ptr->data()), rw(true) {}
    // share another memory buffer, but read-only
    membuf(const membuf &other)
      : ptr(other.ptr), d(other.d), rw(false) {}
#if defined(_MSC_VER)
    // MSVC is broken
    membuf(membuf &other)
      : ptr(other.ptr), d(other.d), rw(other.rw) {}
    membuf(membuf &&other)
      : ptr(move(other.ptr)), d(move(other.d)), rw(move(other.rw)) {}
#else
    // share another memory buffer, using the same read/write permissions
    membuf(membuf &other) = default;
    // take over another memory buffer
    membuf(membuf &&other) = default;
#endif

  public:
    // read/write access to element #i
    template<typename I> T &vraw(I i)
      {
      MR_assert(rw, "array is not writable");
      return const_cast<T *>(d)[i];
      }
    // read access to element #i
    template<typename I> const T &operator[](I i) const
      { return d[i]; }
    // read/write access to data area
    const T *data() const
      { return d; }
    // read access to data area
    T *vdata()
      {
      MR_assert(rw, "array is not writable");
      return const_cast<T *>(d);
      }
    bool writable() const { return rw; }
  };

Martin Reinecke's avatar
Martin Reinecke committed
98
99
class fmav_info
  {
Martin Reinecke's avatar
Martin Reinecke committed
100
101
102
103
  public:
    using shape_t = vector<size_t>;
    using stride_t = vector<ptrdiff_t>;

Martin Reinecke's avatar
Martin Reinecke committed
104
105
106
  protected:
    shape_t shp;
    stride_t str;
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
107
    size_t sz;
Martin Reinecke's avatar
Martin Reinecke committed
108

109
110
111
112
113
114
115
116
117
118
    static stride_t shape2stride(const shape_t &shp)
      {
      auto ndim = shp.size();
      stride_t res(ndim);
      res[ndim-1]=1;
      for (size_t i=2; i<=ndim; ++i)
        res[ndim-i] = res[ndim-i+1]*ptrdiff_t(shp[ndim-i+1]);
      return res;
      }
    template<typename... Ns> ptrdiff_t getIdx(size_t dim, size_t n, Ns... ns) const
Martin Reinecke's avatar
fix    
Martin Reinecke committed
119
      { return str[dim]*ptrdiff_t(n) + getIdx(dim+1, ns...); }
120
    ptrdiff_t getIdx(size_t dim, size_t n) const
Martin Reinecke's avatar
fix    
Martin Reinecke committed
121
      { return str[dim]*ptrdiff_t(n); }
Martin Reinecke's avatar
Martin Reinecke committed
122
123
124

  public:
    fmav_info(const shape_t &shape_, const stride_t &stride_)
Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
125
      : shp(shape_), str(stride_), sz(reduce(shp.begin(),shp.end(),size_t(1),multiplies<>()))
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
126
127
128
129
      {
      MR_assert(shp.size()>0, "at least 1D required");
      MR_assert(shp.size()==str.size(), "dimensions mismatch");
      }
Martin Reinecke's avatar
Martin Reinecke committed
130
    fmav_info(const shape_t &shape_)
131
      : fmav_info(shape_, shape2stride(shape_)) {}
Martin Reinecke's avatar
Martin Reinecke committed
132
    size_t ndim() const { return shp.size(); }
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
133
    size_t size() const { return sz; }
Martin Reinecke's avatar
Martin Reinecke committed
134
135
136
137
138
139
140
141
142
143
144
145
146
    const shape_t &shape() const { return shp; }
    size_t shape(size_t i) const { return shp[i]; }
    const stride_t &stride() const { return str; }
    const ptrdiff_t &stride(size_t i) const { return str[i]; }
    bool last_contiguous() const
      { return (str.back()==1); }
    bool contiguous() const
      {
      auto ndim = shp.size();
      ptrdiff_t stride=1;
      for (size_t i=0; i<ndim; ++i)
        {
        if (str[ndim-1-i]!=stride) return false;
Martin Reinecke's avatar
Martin Reinecke committed
147
        stride *= ptrdiff_t(shp[ndim-1-i]);
Martin Reinecke's avatar
Martin Reinecke committed
148
149
150
151
152
        }
      return true;
      }
    bool conformable(const fmav_info &other) const
      { return shp==other.shp; }
153
154
155
156
157
    template<typename... Ns> ptrdiff_t idx(Ns... ns) const
      {
      MR_assert(ndim()==sizeof...(ns), "incorrect number of indices");
      return getIdx(0, ns...);
      }
Martin Reinecke's avatar
Martin Reinecke committed
158
159
  };

160
// Shape and stride description of an array whose dimensionality ndim is
// fixed at compile time.
template<size_t ndim> class mav_info
  {
  protected:
    static_assert(ndim>0, "at least 1D required");

    using shape_t = array<size_t, ndim>;
    using stride_t = array<ptrdiff_t, ndim>;

    shape_t shp;   // number of entries along each axis
    stride_t str;  // distance (in elements) between neighbours along each axis
    size_t sz;     // total number of elements, i.e. the product of shp

    // Strides of a contiguous C-ordered (row-major) array with the given
    // shape: the last axis gets stride 1.
    static stride_t shape2stride(const shape_t &shp)
      {
      stride_t res;
      res[ndim-1]=1;
      for (size_t i=2; i<=ndim; ++i)
        res[ndim-i] = res[ndim-i+1]*ptrdiff_t(shp[ndim-i+1]);
      return res;
      }
    // Memory offset contributed by the indices (n, ns...) starting at axis dim.
    // n is cast to ptrdiff_t so the product stays signed (as in fmav_info);
    // ptrdiff_t*size_t would be evaluated in size_t and only produce the
    // right value for negative strides through modular wrap-around.
    template<typename... Ns> ptrdiff_t getIdx(size_t dim, size_t n, Ns... ns) const
      { return str[dim]*ptrdiff_t(n) + getIdx(dim+1, ns...); }
    ptrdiff_t getIdx(size_t dim, size_t n) const
      { return str[dim]*ptrdiff_t(n); }

  public:
    // Array with explicit shape and strides.
    mav_info(const shape_t &shape_, const stride_t &stride_)
      : shp(shape_), str(stride_), sz(reduce(shp.begin(),shp.end(),size_t(1),multiplies<>())) {}
    // Contiguous C-ordered array with the given shape.
    mav_info(const shape_t &shape_)
      : mav_info(shape_, shape2stride(shape_)) {}
    size_t size() const { return sz; }
    const shape_t &shape() const { return shp; }
    size_t shape(size_t i) const { return shp[i]; }
    const stride_t &stride() const { return str; }
    const ptrdiff_t &stride(size_t i) const { return str[i]; }
    // true if the last axis has unit stride
    bool last_contiguous() const
      { return (str.back()==1); }
    // true if the strides are exactly those of a contiguous C-ordered array
    bool contiguous() const
      {
      ptrdiff_t stride=1;
      for (size_t i=0; i<ndim; ++i)
        {
        if (str[ndim-1-i]!=stride) return false;
        stride *= ptrdiff_t(shp[ndim-1-i]);
        }
      return true;
      }
    // true if the shapes agree (strides are not compared)
    bool conformable(const mav_info &other) const
      { return shp==other.shp; }
    bool conformable(const shape_t &other) const
      { return shp==other; }
    // Memory offset of the given multi-index; exactly ndim indices are
    // required (checked at compile time).
    template<typename... Ns> ptrdiff_t idx(Ns... ns) const
      {
      static_assert(ndim==sizeof...(ns), "incorrect number of indices");
      return getIdx(0, ns...);
      }
  };
217

218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247

// Walks over every element of an array described by an fmav_info in
// C (row-major) order, i.e. the last axis varies fastest, and keeps track
// of the memory offset of the current element.
class FmavIter
  {
  private:
    fmav_info::shape_t pos;   // multi-index of the current element
    fmav_info arr;            // shape/stride description being traversed
    ptrdiff_t p = 0;          // memory offset of the current element
    size_t rem;               // starts at arr.size(), decremented by advance()

  public:
    FmavIter(const fmav_info &arr_)
      : pos(arr_.ndim(), size_t(0)), arr(arr_), rem(arr_.size()) {}

    // Step to the next element, carrying into slower axes as they wrap.
    void advance()
      {
      --rem;
      for (size_t ax=pos.size(); ax-->0; )
        {
        p += arr.stride(ax);
        if (++pos[ax] < arr.shape(ax))
          return;   // no carry needed
        // this axis wrapped around: rewind it and carry into the next one
        p -= ptrdiff_t(arr.shape(ax))*arr.stride(ax);
        pos[ax] = 0;
        }
      }

    ptrdiff_t ofs() const
      { return p; }
    size_t remaining() const
      { return rem; }
  };


Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
248
249
// "mav" stands for "multidimensional array view"
template<typename T> class fmav: public fmav_info, public membuf<T>
Martin Reinecke's avatar
Martin Reinecke committed
250
  {
Martin Reinecke's avatar
Martin Reinecke committed
251
  protected:
Martin Reinecke's avatar
Martin Reinecke committed
252
253
    using tbuf = membuf<T>;
    using tinfo = fmav_info;
Martin Reinecke's avatar
Martin Reinecke committed
254

Martin Reinecke's avatar
Martin Reinecke committed
255
    auto subdata(const shape_t &i0, const shape_t &extent) const
Martin Reinecke's avatar
Martin Reinecke committed
256
      {
Martin Reinecke's avatar
Martin Reinecke committed
257
258
259
260
      auto ndim = tinfo::ndim();
      shape_t nshp(ndim);
      stride_t nstr(ndim);
      ptrdiff_t nofs;
Martin Reinecke's avatar
Martin Reinecke committed
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
      MR_assert(i0.size()==ndim, "bad domensionality");
      MR_assert(extent.size()==ndim, "bad domensionality");
      size_t n0=0;
      for (auto x:extent) if (x==0) ++n0;
      nofs=0;
      nshp.resize(ndim-n0);
      nstr.resize(ndim-n0);
      for (size_t i=0, i2=0; i<ndim; ++i)
        {
        MR_assert(i0[i]<shp[i], "bad subset");
        nofs+=i0[i]*str[i];
        if (extent[i]!=0)
          {
          MR_assert(i0[i]+extent[i2]<=shp[i], "bad subset");
          nshp[i2] = extent[i]; nstr[i2]=str[i];
          ++i2;
          }
        }
Martin Reinecke's avatar
Martin Reinecke committed
279
      return make_tuple(nshp, nstr, ndim);
Martin Reinecke's avatar
Martin Reinecke committed
280
281
      }

Martin Reinecke's avatar
Martin Reinecke committed
282
  public:
Martin Reinecke's avatar
Martin Reinecke committed
283
284
    using tbuf::vraw, tbuf::operator[], tbuf::vdata, tbuf::data;

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
285
    fmav(const T *d_, const shape_t &shp_, const stride_t &str_)
Martin Reinecke's avatar
Martin Reinecke committed
286
      : tinfo(shp_, str_), tbuf(d_) {}
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
287
    fmav(const T *d_, const shape_t &shp_)
Martin Reinecke's avatar
Martin Reinecke committed
288
      : tinfo(shp_), tbuf(d_) {}
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
289
    fmav(T *d_, const shape_t &shp_, const stride_t &str_, bool rw_)
Martin Reinecke's avatar
Martin Reinecke committed
290
      : tinfo(shp_, str_), tbuf(d_,rw_) {}
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
291
    fmav(T *d_, const shape_t &shp_, bool rw_)
Martin Reinecke's avatar
Martin Reinecke committed
292
      : tinfo(shp_), tbuf(d_,rw_) {}
293
    fmav(const shape_t &shp_)
Martin Reinecke's avatar
Martin Reinecke committed
294
295
296
297
298
      : tinfo(shp_), tbuf(size()) {}
    fmav(const T* d_, const tinfo &info)
      : tinfo(info), tbuf(d_) {}
    fmav(T* d_, const tinfo &info, bool rw_=false)
      : tinfo(info), tbuf(d_, rw_) {}
Martin Reinecke's avatar
Martin Reinecke committed
299
#if defined(_MSC_VER)
Martin Reinecke's avatar
tweak    
Martin Reinecke committed
300
    // MSVC is broken
Martin Reinecke's avatar
Martin Reinecke committed
301
302
303
    fmav(const fmav &other) : tinfo(other), tbuf(other) {};
    fmav(fmav &other) : tinfo(other), tbuf(other) {}
    fmav(fmav &&other) : tinfo(other), tbuf(other) {}
Martin Reinecke's avatar
Martin Reinecke committed
304
#else
Martin Reinecke's avatar
Martin Reinecke committed
305
    fmav(const fmav &other) = default;
Martin Reinecke's avatar
Martin Reinecke committed
306
    fmav(fmav &other) = default;
Martin Reinecke's avatar
tweak    
Martin Reinecke committed
307
    fmav(fmav &&other) = default;
Martin Reinecke's avatar
tweak    
Martin Reinecke committed
308
#endif
Martin Reinecke's avatar
Martin Reinecke committed
309
310
311
312
313
314
315
316
    fmav(tbuf &buf, const shape_t &shp_, const stride_t &str_)
      : tinfo(shp_, str_), tbuf(buf) {}
    fmav(const tbuf &buf, const shape_t &shp_, const stride_t &str_)
      : tinfo(shp_, str_), tbuf(buf) {}
    fmav(const shape_t &shp_, const stride_t &str_, const T *d_, tbuf &buf)
      : tinfo(shp_, str_), tbuf(d_, buf) {}
    fmav(const shape_t &shp_, const stride_t &str_, const T *d_, const tbuf &buf)
      : tinfo(shp_, str_), tbuf(d_, buf) {}
317
318
319
320
321

    template<typename... Ns> const T &operator()(Ns... ns) const
      { return operator[](idx(ns...)); }
    template<typename... Ns> T &v(Ns... ns)
      { return vraw(idx(ns...)); }
Martin Reinecke's avatar
Martin Reinecke committed
322
323
324

    fmav subarray(const shape_t &i0, const shape_t &extent)
      {
Martin Reinecke's avatar
Martin Reinecke committed
325
      auto [nshp, nstr, nofs] = subdata(i0, extent);
326
      return fmav(nshp, nstr, tbuf::d+nofs, *this);
Martin Reinecke's avatar
Martin Reinecke committed
327
328
329
      }
    fmav subarray(const shape_t &i0, const shape_t &extent) const
      {
Martin Reinecke's avatar
Martin Reinecke committed
330
      auto [nshp, nstr, nofs] = subdata(i0, extent);
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
      return fmav(nshp, nstr, tbuf::d+nofs, *this);
      }
    vector<T> dump() const
      {
      FmavIter it(*this);
      vector<T> res(sz);
      for (size_t i=0; i<sz; ++i, it.advance())
        res[i] = operator[](it.ofs());
      return res;
      }
    void load (const vector<T> &v)
      {
      MR_assert(v.size()==sz, "bad input data size");
      FmavIter it(*this);
      for (size_t i=0; i<sz; ++i, it.advance())
        vraw(it.ofs()) = v[i];
Martin Reinecke's avatar
Martin Reinecke committed
347
      }
Martin Reinecke's avatar
Martin Reinecke committed
348
349
  };

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
350
template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membuf<T>
Martin Reinecke's avatar
Martin Reinecke committed
351
  {
352
//  static_assert((ndim>0) && (ndim<4), "only supports 1D, 2D, and 3D arrays");
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
353

Martin Reinecke's avatar
Martin Reinecke committed
354
  protected:
Martin Reinecke's avatar
Martin Reinecke committed
355
356
357
358
359
    using tinfo = mav_info<ndim>;
    using tbuf = membuf<T>;
    using typename tinfo::shape_t;
    using typename tinfo::stride_t;
    using tinfo::shp, tinfo::str;
Martin Reinecke's avatar
Martin Reinecke committed
360

361
362
363
364
365
366
367
368
369
370
371
372
    template<size_t idim, typename Func> void applyHelper(ptrdiff_t idx, Func func)
      {
      if constexpr (idim+1<ndim)
        for (size_t i=0; i<shp[idim]; ++i)
          applyHelper<idim+1, Func>(idx+i*str[idim], func);
      else
        {
        T *d2 = vdata();
        for (size_t i=0; i<shp[idim]; ++i)
          func(d2[idx+i*str[idim]]);
        }
      }
Martin Reinecke's avatar
Martin Reinecke committed
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
    template<size_t idim, typename Func> void applyHelper(ptrdiff_t idx, Func func) const
      {
      if constexpr (idim+1<ndim)
        for (size_t i=0; i<shp[idim]; ++i)
          applyHelper<idim+1, Func>(idx+i*str[idim], func);
      else
        {
        const T *d2 = data();
        for (size_t i=0; i<shp[idim]; ++i)
          func(d2[idx+i*str[idim]]);
        }
      }
    template<size_t idim, typename T2, typename Func>
      void applyHelper(ptrdiff_t idx, ptrdiff_t idx2,
                       const mav<T2,ndim> &other, Func func)
      {
      if constexpr (idim==0)
        MR_assert(conformable(other), "dimension mismatch");
      if constexpr (idim+1<ndim)
        for (size_t i=0; i<shp[idim]; ++i)
          applyHelper<idim+1, T2, Func>(idx+i*str[idim],
                                        idx2+i*other.str[idim], other, func);
      else
        {
        T *d2 = vdata();
        const T2 *d3 = other.data();
        for (size_t i=0; i<shp[idim]; ++i)
          func(d2[idx+i*str[idim]],d3[idx2+i*other.str[idim]]);
        }
      }
Martin Reinecke's avatar
Martin Reinecke committed
403

Martin Reinecke's avatar
Martin Reinecke committed
404
    template<size_t nd2> auto subdata(const shape_t &i0, const shape_t &extent) const
Martin Reinecke's avatar
Martin Reinecke committed
405
      {
Martin Reinecke's avatar
Martin Reinecke committed
406
407
408
      array<size_t, nd2> nshp;
      array<ptrdiff_t, nd2> nstr;
      ptrdiff_t nofs;
Martin Reinecke's avatar
Martin Reinecke committed
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
      size_t n0=0;
      for (auto x:extent) if (x==0)++n0;
      MR_assert(n0+nd2==ndim, "bad extent");
      nofs=0;
      for (size_t i=0, i2=0; i<ndim; ++i)
        {
        MR_assert(i0[i]<shp[i], "bad subset");
        nofs+=i0[i]*str[i];
        if (extent[i]!=0)
          {
          MR_assert(i0[i]+extent[i2]<=shp[i], "bad subset");
          nshp[i2] = extent[i]; nstr[i2]=str[i];
          ++i2;
          }
        }
Martin Reinecke's avatar
Martin Reinecke committed
424
      return make_tuple(nshp, nstr, nofs);
Martin Reinecke's avatar
Martin Reinecke committed
425
426
      }

Martin Reinecke's avatar
Martin Reinecke committed
427
  public:
Martin Reinecke's avatar
Martin Reinecke committed
428
429
    using tbuf::vraw, tbuf::operator[], tbuf::vdata, tbuf::data;
    using tinfo::contiguous, tinfo::size, tinfo::idx, tinfo::conformable;
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
430

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
431
    mav(const T *d_, const shape_t &shp_, const stride_t &str_)
Martin Reinecke's avatar
Martin Reinecke committed
432
      : tinfo(shp_, str_), tbuf(d_) {}
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
433
    mav(T *d_, const shape_t &shp_, const stride_t &str_, bool rw_=false)
Martin Reinecke's avatar
Martin Reinecke committed
434
      : tinfo(shp_, str_), tbuf(d_, rw_) {}
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
435
    mav(const T *d_, const shape_t &shp_)
Martin Reinecke's avatar
Martin Reinecke committed
436
      : tinfo(shp_), tbuf(d_) {}
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
437
    mav(T *d_, const shape_t &shp_, bool rw_=false)
Martin Reinecke's avatar
Martin Reinecke committed
438
      : tinfo(shp_), tbuf(d_, rw_) {}
439
    mav(const array<size_t,ndim> &shp_)
Martin Reinecke's avatar
Martin Reinecke committed
440
      : tinfo(shp_), tbuf(size()) {}
Martin Reinecke's avatar
Martin Reinecke committed
441
#if defined(_MSC_VER)
Martin Reinecke's avatar
tweak    
Martin Reinecke committed
442
    // MSVC is broken
Martin Reinecke's avatar
Martin Reinecke committed
443
444
445
    mav(const mav &other) : tinfo(other), tbuf(other) {}
    mav(mav &other): tinfo(other), tbuf(other) {}
    mav(mav &&other): tinfo(other), tbuf(other) {}
Martin Reinecke's avatar
Martin Reinecke committed
446
#else
Martin Reinecke's avatar
Martin Reinecke committed
447
    mav(const mav &other) = default;
Martin Reinecke's avatar
Martin Reinecke committed
448
    mav(mav &other) = default;
Martin Reinecke's avatar
tweak    
Martin Reinecke committed
449
    mav(mav &&other) = default;
Martin Reinecke's avatar
tweak    
Martin Reinecke committed
450
#endif
Martin Reinecke's avatar
Martin Reinecke committed
451
452
453
454
    mav(const shape_t &shp_, const stride_t &str_, const T *d_, membuf<T> &mb)
      : mav_info<ndim>(shp_, str_), membuf<T>(d_, mb) {}
    mav(const shape_t &shp_, const stride_t &str_, const T *d_, const membuf<T> &mb)
      : mav_info<ndim>(shp_, str_), membuf<T>(d_, mb) {}
Martin Reinecke's avatar
Martin Reinecke committed
455
456
    operator fmav<T>() const
      {
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
457
      return fmav<T>(*this, {shp.begin(), shp.end()}, {str.begin(), str.end()});
Martin Reinecke's avatar
Martin Reinecke committed
458
      }
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
459
    operator fmav<T>()
Martin Reinecke's avatar
Martin Reinecke committed
460
      {
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
461
      return fmav<T>(*this, {shp.begin(), shp.end()}, {str.begin(), str.end()});
Martin Reinecke's avatar
Martin Reinecke committed
462
      }
Martin Reinecke's avatar
Martin Reinecke committed
463
464
465
466
    template<typename... Ns> const T &operator()(Ns... ns) const
      { return operator[](idx(ns...)); }
    template<typename... Ns> T &v(Ns... ns)
      { return vraw(idx(ns...)); }
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
467
    template<typename Func> void apply(Func func)
468
      {
Martin Reinecke's avatar
Martin Reinecke committed
469
470
      if (contiguous())
        {
471
        T *d2 = vdata();
Martin Reinecke's avatar
Martin Reinecke committed
472
        for (auto v=d2; v!=d2+size(); ++v)
Martin Reinecke's avatar
Martin Reinecke committed
473
474
475
          func(*v);
        return;
        }
476
      applyHelper<0,Func>(0,func);
Martin Reinecke's avatar
Martin Reinecke committed
477
      }
Martin Reinecke's avatar
Martin Reinecke committed
478
479
480
    template<typename T2, typename Func> void apply
      (const mav<T2, ndim> &other,Func func)
      { applyHelper<0,T2,Func>(0,0,other,func); }
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
481
482
    void fill(const T &val)
      { apply([val](T &v){v=val;}); }
Martin Reinecke's avatar
Martin Reinecke committed
483
484
    template<size_t nd2> mav<T,nd2> subarray(const shape_t &i0, const shape_t &extent)
      {
Martin Reinecke's avatar
Martin Reinecke committed
485
486
      auto [nshp, nstr, nofs] = subdata<nd2> (i0, extent);
      return mav<T,nd2> (nshp, nstr, tbuf::d+nofs, *this);
Martin Reinecke's avatar
Martin Reinecke committed
487
488
489
      }
    template<size_t nd2> mav<T,nd2> subarray(const shape_t &i0, const shape_t &extent) const
      {
Martin Reinecke's avatar
Martin Reinecke committed
490
491
      auto [nshp, nstr, nofs] = subdata<nd2> (i0, extent);
      return mav<T,nd2> (nshp, nstr, tbuf::d+nofs, *this);
Martin Reinecke's avatar
Martin Reinecke committed
492
      }
Martin Reinecke's avatar
Martin Reinecke committed
493
494
  };

Martin Reinecke's avatar
Martin Reinecke committed
495
496
497
498
499
500
template<typename T, size_t ndim> class MavIter
  {
  protected:
    fmav<T> mav;
    array<size_t, ndim> shp;
    array<ptrdiff_t, ndim> str;
Martin Reinecke's avatar
Martin Reinecke committed
501
    fmav_info::shape_t pos;
Martin Reinecke's avatar
Martin Reinecke committed
502
503
504
    ptrdiff_t idx_;
    bool done_;

Martin Reinecke's avatar
Martin Reinecke committed
505
506
507
508
509
    template<typename... Ns> ptrdiff_t getIdx(size_t dim, size_t n, Ns... ns) const
      { return str[dim]*n + getIdx(dim+1, ns...); }
    ptrdiff_t getIdx(size_t dim, size_t n) const
      { return str[dim]*n; }

Martin Reinecke's avatar
Martin Reinecke committed
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
  public:
    MavIter(const fmav<T> &mav_)
      : mav(mav_), pos(mav.ndim()-ndim,0), idx_(0), done_(false)
      {
      for (size_t i=0; i<ndim; ++i)
        {
        shp[i] = mav.shape(mav.ndim()-ndim+i);
        str[i] = mav.stride(mav.ndim()-ndim+i);
        }
      }
    MavIter(fmav<T> &mav_)
      : mav(mav_), pos(mav.ndim()-ndim,0), idx_(0), done_(false)
      {
      for (size_t i=0; i<ndim; ++i)
        {
        shp[i] = mav.shape(mav.ndim()-ndim+i);
        str[i] = mav.stride(mav.ndim()-ndim+i);
        }
      }
    bool done() const
      { return done_; }
    void inc()
      {
      for (ptrdiff_t i=mav.ndim()-ndim-1; i>=0; --i)
        {
        idx_+=mav.stride(i);
        if (++pos[i]<mav.shape(i)) return;
        pos[i]=0;
        idx_-=mav.shape(i)*mav.stride(i);
        }
      done_=true;
      }
    size_t shape(size_t i) const { return shp[i]; }
Martin Reinecke's avatar
Martin Reinecke committed
543
544
545
546
547
548
    template<typename... Ns> ptrdiff_t idx(Ns... ns) const
      { return idx_ + getIdx(0, ns...); }
    template<typename... Ns> const T &operator()(Ns... ns) const
      { return mav[idx(ns...)]; }
    template<typename... Ns> T &v(Ns... ns)
      { return mav.vraw(idx(ns...)); }
Martin Reinecke's avatar
Martin Reinecke committed
549
550
  };

Martin Reinecke's avatar
Martin Reinecke committed
551
552
}

Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
553
// Public names re-exported from the implementation namespace:
// descriptors and iterators first, then the array-view class templates.
using detail_mav::fmav_info, detail_mav::FmavIter;
using detail_mav::fmav, detail_mav::mav, detail_mav::MavIter;
Martin Reinecke's avatar
Martin Reinecke committed
558

Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
559
}
Martin Reinecke's avatar
Martin Reinecke committed
560
561

#endif