mav.h 8.86 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
/*
 *  This file is part of the MR utility library.
 *
 *  This code is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This code is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this code; if not, write to the Free Software
 *  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 */

/* Copyright (C) 2019 Max-Planck-Society
   Author: Martin Reinecke */

Martin Reinecke's avatar
Martin Reinecke committed
22 23 24 25 26
#ifndef MRUTIL_MAV_H
#define MRUTIL_MAV_H

#include <cstdlib>
#include <array>
Martin Reinecke's avatar
Martin Reinecke committed
27
#include <vector>
28
#include <memory>
Martin Reinecke's avatar
Martin Reinecke committed
29
#include "mr_util/error_handling.h"
Martin Reinecke's avatar
Martin Reinecke committed
30 31 32

namespace mr {

Martin Reinecke's avatar
tweaks  
Martin Reinecke committed
33
namespace detail_mav {
Martin Reinecke's avatar
Martin Reinecke committed
34 35 36

// Lets the code in detail_mav use standard-library names unqualified. The
// directive sits inside namespace detail_mav rather than at file scope.
using namespace std;

Martin Reinecke's avatar
Martin Reinecke committed
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
// Run-time sized list of extents, one entry per array axis.
using shape_t = vector<size_t>;
// Run-time sized list of strides, one entry per array axis. The entries are
// signed.
using stride_t = vector<ptrdiff_t>;

/// Shape and stride description of an array whose number of dimensions is
/// chosen at run time. The class stores no data pointer.
class fmav_info
  {
  protected:
    shape_t shp;   // extent along each axis
    stride_t str;  // stride along each axis; same length as shp

    /// Returns the product of all entries of \a shape (1 for an empty shape).
    static size_t prod(const shape_t &shape)
      {
      size_t res=1;
      for (auto sz: shape)
        res*=sz;
      return res;
      }

  public:
    /// Builds the description from an explicit shape and stride vector.
    /// MR_assert is invoked with the condition that both have equal length.
    fmav_info(const shape_t &shape_, const stride_t &stride_)
      : shp(shape_), str(stride_)
      { MR_assert(shp.size()==str.size(), "dimensions mismatch"); }
    /// Builds the description from a shape alone, choosing row-major strides:
    /// the last axis gets stride 1 and every earlier axis gets the stride of
    /// its successor multiplied by the successor's extent.
    fmav_info(const shape_t &shape_)
      : shp(shape_), str(shape_.size())
      {
      auto ndim = shp.size();
      // A zero-dimensional shape has no strides to fill in. Without this
      // guard, ndim-1 wraps around and str[ndim-1] indexes an empty vector.
      if (ndim==0) return;
      str[ndim-1]=1;
      for (size_t i=2; i<=ndim; ++i)
        str[ndim-i] = str[ndim-i+1]*ptrdiff_t(shp[ndim-i+1]);
      }
    /// Number of axes.
    size_t ndim() const { return shp.size(); }
    /// Total number of elements, i.e. the product of all extents.
    size_t size() const { return prod(shp); }
    /// All extents.
    const shape_t &shape() const { return shp; }
    /// Extent of axis \a i (not bounds-checked).
    size_t shape(size_t i) const { return shp[i]; }
    /// All strides.
    const stride_t &stride() const { return str; }
    /// Stride of axis \a i (not bounds-checked).
    const ptrdiff_t &stride(size_t i) const { return str[i]; }
    /// True if the last axis has stride 1. A zero-dimensional description has
    /// no last axis; it is reported as contiguous, in agreement with
    /// contiguous(), instead of calling back() on an empty vector.
    bool last_contiguous() const
      { return str.empty() || (str.back()==1); }
    /// True if the strides are exactly the row-major strides for the stored
    /// shape, checked from the last axis towards the first.
    bool contiguous() const
      {
      auto ndim = shp.size();
      ptrdiff_t stride=1;
      for (size_t i=0; i<ndim; ++i)
        {
        if (str[ndim-1-i]!=stride) return false;
        stride *= ptrdiff_t(shp[ndim-1-i]);
        }
      return true;
      }
    /// True if \a other has the same shape; strides are not compared.
    bool conformable(const fmav_info &other) const
      { return shp==other.shp; }
  };

89 90
// Forward declaration of the fixed-rank read-only view; the definition
// follows further down in this file.
template<typename T, size_t ndim> class cmav;

Martin Reinecke's avatar
Martin Reinecke committed
91
// "mav" stands for "multidimensional array view"
Martin Reinecke's avatar
Martin Reinecke committed
92 93 94
template<typename T> class cfmav: public fmav_info
  {
  protected:
95 96
    using Tsp = shared_ptr<vector<T>>;
    Tsp ptr;
Martin Reinecke's avatar
Martin Reinecke committed
97 98 99 100 101 102 103
    T *d;

  public:
    cfmav(const T *d_, const shape_t &shp_, const stride_t &str_)
      : fmav_info(shp_, str_), d(const_cast<T *>(d_)) {}
    cfmav(const T *d_, const shape_t &shp_)
      : fmav_info(shp_), d(const_cast<T *>(d_)) {}
104 105 106
    cfmav(const shape_t &shp_)
      : fmav_info(shp_), ptr(make_unique<vector<T>>(size())),
        d(const_cast<T *>(ptr->data())) {}
Martin Reinecke's avatar
Martin Reinecke committed
107 108
    cfmav(const T* d_, const fmav_info &info)
      : fmav_info(info), d(const_cast<T *>(d_)) {}
109 110 111 112 113
    cfmav(const cfmav &other) = default;
    cfmav(cfmav &&other) = default;
// Not for public use!
    cfmav(const T *d_, const Tsp &p, const shape_t &shp_, const stride_t &str_)
      : fmav_info(shp_, str_), ptr(p), d(const_cast<T *>(d_)) {}
Martin Reinecke's avatar
Martin Reinecke committed
114 115 116 117 118
    template<typename I> const T &operator[](I i) const
      { return d[i]; }
    const T *data() const
      { return d; }
  };
119

Martin Reinecke's avatar
Martin Reinecke committed
120 121 122
/// Writable variant of cfmav<T>: same constructors and queries, but element
/// access and data() yield non-const T. Constness is shallow, so the
/// accessors are const member functions that return mutable references.
template<typename T> class fmav: public cfmav<T>
  {
  protected:
    using parent = cfmav<T>;
    using Tsp = shared_ptr<vector<T>>;
    // Members of the dependent base class must be named explicitly.
    using parent::shp;
    using parent::str;
    using parent::d;

  public:
    // Queries taken over unchanged from the base class.
    using parent::shape;
    using parent::stride;
    using parent::last_contiguous;
    using parent::contiguous;
    using parent::conformable;

    /// Wraps caller-supplied memory with explicit shape and strides.
    fmav(T *d_, const shape_t &shp_, const stride_t &str_)
      : parent(d_, shp_, str_)
      {}
    /// Wraps caller-supplied memory; strides are derived from the shape by
    /// the base class.
    fmav(T *d_, const shape_t &shp_)
      : parent(d_, shp_)
      {}
    /// Forwards to the base-class constructor that takes only a shape.
    fmav(const shape_t &shp_)
      : parent(shp_)
      {}
    /// Wraps caller-supplied memory with shape and strides taken from
    /// \a info.
    fmav(T* d_, const fmav_info &info)
      : parent(d_, info)
      {}
// Not for public use!
    /// Takes the data pointer together with an existing storage handle \a p.
    fmav(T *d_, const Tsp &p, const shape_t &shp_, const stride_t &str_)
      : parent(d_, p, shp_, str_)
      {}
    fmav(const fmav &other) = default;
    fmav(fmav &&other) = default;

    /// Mutable element at raw offset \a i from the data pointer; strides are
    /// not applied.
    template<typename I> T &operator[](I i) const
      { return d[i]; }
    /// Mutable pointer to the start of the viewed data.
    T *data() const
      { return d; }
  };

Martin Reinecke's avatar
Martin Reinecke committed
155
template<typename T, size_t ndim> class cmav
Martin Reinecke's avatar
Martin Reinecke committed
156
  {
157
  static_assert((ndim>0) && (ndim<4), "only supports 1D, 2D, and 3D arrays");
Martin Reinecke's avatar
Martin Reinecke committed
158

Martin Reinecke's avatar
Martin Reinecke committed
159
  protected:
160
    using Tsp = shared_ptr<vector<T>>;
Martin Reinecke's avatar
Martin Reinecke committed
161 162
    array<size_t, ndim> shp;
    array<ptrdiff_t, ndim> str;
163 164
    Tsp ptr;
    T *d;
Martin Reinecke's avatar
Martin Reinecke committed
165 166

  public:
Martin Reinecke's avatar
Martin Reinecke committed
167
    cmav(const T *d_, const array<size_t,ndim> &shp_,
Martin Reinecke's avatar
Martin Reinecke committed
168
        const array<ptrdiff_t,ndim> &str_)
169
      : shp(shp_), str(str_), d(const_cast<T *>(d_)) {}
Martin Reinecke's avatar
Martin Reinecke committed
170
    cmav(const T *d_, const array<size_t,ndim> &shp_)
171 172 173 174 175 176 177 178
      : shp(shp_), d(const_cast<T *>(d_))
      {
      str[ndim-1]=1;
      for (size_t i=2; i<=ndim; ++i)
        str[ndim-i] = str[ndim-i+1]*shp[ndim-i+1];
      }
    cmav(const array<size_t,ndim> &shp_)
      : shp(shp_), ptr(make_unique<vector<T>>(size())), d(ptr->data())
Martin Reinecke's avatar
Martin Reinecke committed
179 180
      {
      str[ndim-1]=1;
Martin Reinecke's avatar
Martin Reinecke committed
181 182
      for (size_t i=2; i<=ndim; ++i)
        str[ndim-i] = str[ndim-i+1]*shp[ndim-i+1];
Martin Reinecke's avatar
Martin Reinecke committed
183
      }
184 185
    cmav(const cmav &other) = default;
    cmav(cmav &&other) = default;
Martin Reinecke's avatar
Martin Reinecke committed
186 187
    operator cfmav<T>() const
      {
188
      return cfmav<T>(d, ptr, {shp.begin(), shp.end()}, {str.begin(), str.end()});
Martin Reinecke's avatar
Martin Reinecke committed
189 190
      }
    const T &operator[](size_t i) const
Martin Reinecke's avatar
Martin Reinecke committed
191
      { return operator()(i); }
Martin Reinecke's avatar
Martin Reinecke committed
192
    const T &operator()(size_t i) const
Martin Reinecke's avatar
Martin Reinecke committed
193 194 195 196
      {
      static_assert(ndim==1, "ndim must be 1");
      return d[str[0]*i];
      }
Martin Reinecke's avatar
Martin Reinecke committed
197
    const T &operator()(size_t i, size_t j) const
Martin Reinecke's avatar
Martin Reinecke committed
198 199 200 201
      {
      static_assert(ndim==2, "ndim must be 2");
      return d[str[0]*i + str[1]*j];
      }
202 203 204 205 206
    const T &operator()(size_t i, size_t j, size_t k) const
      {
      static_assert(ndim==3, "ndim must be 3");
      return d[str[0]*i + str[1]*j + str[2]*k];
      }
Martin Reinecke's avatar
Martin Reinecke committed
207 208 209 210 211 212 213 214 215
    size_t shape(size_t i) const { return shp[i]; }
    const array<size_t,ndim> &shape() const { return shp; }
    size_t size() const
      {
      size_t res=1;
      for (auto v: shp) res*=v;
      return res;
      }
    ptrdiff_t stride(size_t i) const { return str[i]; }
Martin Reinecke's avatar
Martin Reinecke committed
216
    const T *data() const
Martin Reinecke's avatar
Martin Reinecke committed
217 218 219 220 221 222 223 224 225 226 227 228 229
      { return d; }
    bool last_contiguous() const
      { return (str[ndim-1]==1) || (str[ndim-1]==0); }
    bool contiguous() const
      {
      ptrdiff_t stride=1;
      for (size_t i=0; i<ndim; ++i)
        {
        if (str[ndim-1-i]!=stride) return false;
        stride *= shp[ndim-1-i];
        }
      return true;
      }
Martin Reinecke's avatar
Martin Reinecke committed
230 231 232 233 234 235
    template<typename T2> bool conformable(const cmav<T2,ndim> &other) const
      { return shp==other.shp; }
  };
/// Writable variant of cmav<T,ndim>: same constructors and queries, but
/// element access and data() yield non-const T. Constness is shallow, so the
/// accessors and fill() are const member functions that modify the viewed
/// elements, not the view object.
template<typename T, size_t ndim> class mav: public cmav<T, ndim>
  {
  protected:
    using Tsp = shared_ptr<vector<T>>;
    using parent = cmav<T, ndim>;
    // Members of the dependent base class must be named explicitly.
    using parent::d;
    using parent::ptr;
    using parent::shp;
    using parent::str;

  public:
    /// Wraps caller-supplied memory with explicit shape and strides.
    mav(T *d_, const array<size_t,ndim> &shp_,
        const array<ptrdiff_t,ndim> &str_)
      : parent(d_, shp_, str_) {}
    /// Wraps caller-supplied memory; strides are derived from the shape by
    /// the base class.
    mav(T *d_, const array<size_t,ndim> &shp_)
      : parent(d_, shp_) {}
    /// Forwards to the base-class constructor that takes only a shape.
    mav(const array<size_t,ndim> &shp_)
      : parent(shp_) {}
    mav(const mav &other) = default;
    mav(mav &&other) = default;
    /// Converts to the writable run-time-rank view; shape and strides are
    /// copied into vectors and the storage handle is passed along.
    operator fmav<T>() const
      {
      return fmav<T>(d, ptr, {shp.begin(), shp.end()}, {str.begin(), str.end()});
      }
    /// 1D element access; same as operator()(i).
    T &operator[](size_t i) const
      { return operator()(i); }
    // In the accessors and in fill() every index is converted to ptrdiff_t
    // before it is multiplied with a stride. Otherwise the product would be
    // evaluated in unsigned arithmetic and a negative stride would wrap
    // around.
    /// Mutable element (i) of a 1D view; no bounds check.
    T &operator()(size_t i) const
      {
      static_assert(ndim==1, "ndim must be 1");
      return d[str[0]*ptrdiff_t(i)];
      }
    /// Mutable element (i,j) of a 2D view; no bounds check.
    T &operator()(size_t i, size_t j) const
      {
      static_assert(ndim==2, "ndim must be 2");
      return d[str[0]*ptrdiff_t(i) + str[1]*ptrdiff_t(j)];
      }
    /// Mutable element (i,j,k) of a 3D view; no bounds check.
    T &operator()(size_t i, size_t j, size_t k) const
      {
      static_assert(ndim==3, "ndim must be 3");
      return d[str[0]*ptrdiff_t(i) + str[1]*ptrdiff_t(j) + str[2]*ptrdiff_t(k)];
      }
    using parent::shape;
    using parent::stride;
    using parent::size;
    /// Mutable pointer to the start of the viewed data.
    T *data() const
      { return d; }
    using parent::last_contiguous;
    using parent::contiguous;
    /// Assigns \a val to every element addressed by the view, visiting the
    /// elements through the strides one by one.
    void fill(const T &val) const
      {
      // FIXME: special cases for contiguous arrays and/or zeroing?
      // The rank is a compile-time constant, so only one branch ever runs;
      // the other branches still have to compile for every rank.
      if (ndim==1)
        for (size_t i=0; i<shp[0]; ++i)
          d[str[0]*ptrdiff_t(i)]=val;
      else if (ndim==2)
        for (size_t i=0; i<shp[0]; ++i)
          for (size_t j=0; j<shp[1]; ++j)
            d[str[0]*ptrdiff_t(i) + str[1]*ptrdiff_t(j)] = val;
      else if (ndim==3)
        for (size_t i=0; i<shp[0]; ++i)
          for (size_t j=0; j<shp[1]; ++j)
            for (size_t k=0; k<shp[2]; ++k)
              d[str[0]*ptrdiff_t(i) + str[1]*ptrdiff_t(j)
                + str[2]*ptrdiff_t(k)] = val;
      }
    using parent::conformable;
  };

}

Martin Reinecke's avatar
Martin Reinecke committed
302 303 304 305 306
// Make the names from detail_mav available directly in the enclosing
// namespace.
// Index types and the shape/stride description:
using detail_mav::shape_t;
using detail_mav::stride_t;
using detail_mav::fmav_info;
// Views with run-time rank:
using detail_mav::cfmav;
using detail_mav::fmav;
// Views with compile-time rank:
using detail_mav::cmav;
using detail_mav::mav;
Martin Reinecke's avatar
Martin Reinecke committed
309

Martin Reinecke's avatar
tweaks  
Martin Reinecke committed
310
}
Martin Reinecke's avatar
Martin Reinecke committed
311 312

#endif