Commit 3b01452f authored by Martin Reinecke's avatar Martin Reinecke
Browse files

allow direct indexing of fmav objects

parent 100e4215
Pipeline #74283 failed with stages
in 43 seconds
......@@ -34,6 +34,67 @@ namespace detail_mav {
using namespace std;
template<typename T> class membuf
{
protected:
using Tsp = shared_ptr<vector<T>>;
Tsp ptr;
const T *d;
bool rw;
membuf(const T *d_, membuf &other)
: ptr(other.ptr), d(d_), rw(other.rw) {}
membuf(const T *d_, const membuf &other)
: ptr(other.ptr), d(d_), rw(false) {}
protected:
// externally owned data pointer
membuf(T *d_, bool rw_=false)
: d(d_), rw(rw_) {}
// externally owned data pointer, nonmodifiable
membuf(const T *d_)
: d(d_), rw(false) {}
// allocate own memory
membuf(size_t sz)
: ptr(make_unique<vector<T>>(sz)), d(ptr->data()), rw(true) {}
// share another memory buffer, but read-only
membuf(const membuf &other)
: ptr(other.ptr), d(other.d), rw(false) {}
#if defined(_MSC_VER)
// MSVC is broken
membuf(membuf &other)
: ptr(other.ptr), d(other.d), rw(other.rw) {}
membuf(membuf &&other)
: ptr(move(other.ptr)), d(move(other.d)), rw(move(other.rw)) {}
#else
// share another memory buffer, using the same read/write permissions
membuf(membuf &other) = default;
// take over another memory buffer
membuf(membuf &&other) = default;
#endif
public:
// read/write access to element #i
template<typename I> T &vraw(I i)
{
MR_assert(rw, "array is not writable");
return const_cast<T *>(d)[i];
}
// read access to element #i
template<typename I> const T &operator[](I i) const
{ return d[i]; }
// read/write access to data area
const T *data() const
{ return d; }
// read access to data area
T *vdata()
{
MR_assert(rw, "array is not writable");
return const_cast<T *>(d);
}
bool writable() const { return rw; }
};
class fmav_info
{
public:
......@@ -52,6 +113,19 @@ class fmav_info
res*=sz;
return res;
}
static stride_t shape2stride(const shape_t &shp)
{
auto ndim = shp.size();
stride_t res(ndim);
res[ndim-1]=1;
for (size_t i=2; i<=ndim; ++i)
res[ndim-i] = res[ndim-i+1]*ptrdiff_t(shp[ndim-i+1]);
return res;
}
template<typename... Ns> ptrdiff_t getIdx(size_t dim, size_t n, Ns... ns) const
{ return str[dim]*n + getIdx(dim+1, ns...); }
ptrdiff_t getIdx(size_t dim, size_t n) const
{ return str[dim]*n; }
public:
fmav_info(const shape_t &shape_, const stride_t &stride_)
......@@ -61,14 +135,7 @@ class fmav_info
MR_assert(shp.size()==str.size(), "dimensions mismatch");
}
fmav_info(const shape_t &shape_)
: shp(shape_), str(shape_.size()), sz(prod(shp))
{
auto ndim = shp.size();
MR_assert(ndim>0, "at least 1D required");
str[ndim-1]=1;
for (size_t i=2; i<=ndim; ++i)
str[ndim-i] = str[ndim-i+1]*ptrdiff_t(shp[ndim-i+1]);
}
: fmav_info(shape_, shape2stride(shape_)) {}
size_t ndim() const { return shp.size(); }
size_t size() const { return sz; }
const shape_t &shape() const { return shp; }
......@@ -90,57 +157,76 @@ class fmav_info
}
bool conformable(const fmav_info &other) const
{ return shp==other.shp; }
template<typename... Ns> ptrdiff_t idx(Ns... ns) const
{
MR_assert(ndim()==sizeof...(ns), "incorrect number of indices");
return getIdx(0, ns...);
}
};
template<typename T> class membuf
template<size_t ndim> class mav_info
{
protected:
using Tsp = shared_ptr<vector<T>>;
Tsp ptr;
const T *d;
bool rw;
static_assert(ndim>0, "at least 1D required");
membuf(const T *d_, membuf &other)
: ptr(other.ptr), d(d_), rw(other.rw) {}
membuf(const T *d_, const membuf &other)
: ptr(other.ptr), d(d_), rw(false) {}
using shape_t = array<size_t, ndim>;
using stride_t = array<ptrdiff_t, ndim>;
protected:
membuf(T *d_, bool rw_=false)
: d(d_), rw(rw_) {}
membuf(const T *d_)
: d(d_), rw(false) {}
membuf(size_t sz)
: ptr(make_unique<vector<T>>(sz)), d(ptr->data()), rw(true) {}
membuf(const membuf &other)
: ptr(other.ptr), d(other.d), rw(false) {}
#if defined(_MSC_VER)
// MSVC is broken
membuf(membuf &other)
: ptr(other.ptr), d(other.d), rw(other.rw) {}
membuf(membuf &&other)
: ptr(move(other.ptr)), d(move(other.d)), rw(move(other.rw)) {}
#else
membuf(membuf &other) = default;
membuf(membuf &&other) = default;
#endif
shape_t shp;
stride_t str;
size_t sz;
static size_t prod(const shape_t &shape)
{
size_t res=1;
for (auto sz: shape)
res*=sz;
return res;
}
static stride_t shape2stride(const shape_t &shp)
{
stride_t res;
res[ndim-1]=1;
for (size_t i=2; i<=ndim; ++i)
res[ndim-i] = res[ndim-i+1]*ptrdiff_t(shp[ndim-i+1]);
return res;
}
template<typename... Ns> ptrdiff_t getIdx(size_t dim, size_t n, Ns... ns) const
{ return str[dim]*n + getIdx(dim+1, ns...); }
ptrdiff_t getIdx(size_t dim, size_t n) const
{ return str[dim]*n; }
public:
template<typename I> T &vraw(I i)
mav_info(const shape_t &shape_, const stride_t &stride_)
: shp(shape_), str(stride_), sz(prod(shp)) {}
mav_info(const shape_t &shape_)
: mav_info(shape_, shape2stride(shape_)) {}
size_t size() const { return sz; }
const shape_t &shape() const { return shp; }
size_t shape(size_t i) const { return shp[i]; }
const stride_t &stride() const { return str; }
const ptrdiff_t &stride(size_t i) const { return str[i]; }
bool last_contiguous() const
{ return (str.back()==1); }
bool contiguous() const
{
MR_assert(rw, "array is not writable");
return const_cast<T *>(d)[i];
ptrdiff_t stride=1;
for (size_t i=0; i<ndim; ++i)
{
if (str[ndim-1-i]!=stride) return false;
stride *= ptrdiff_t(shp[ndim-1-i]);
}
return true;
}
template<typename I> const T &operator[](I i) const
{ return d[i]; }
const T *data() const
{ return d; }
T *vdata()
bool conformable(const mav_info &other) const
{ return shp==other.shp; }
bool conformable(const shape_t &other) const
{ return shp==other; }
template<typename... Ns> ptrdiff_t idx(Ns... ns) const
{
MR_assert(rw, "array is not writable");
return const_cast<T *>(d);
static_assert(ndim==sizeof...(ns), "incorrect number of indices");
return getIdx(0, ns...);
}
bool writable() const { return rw; }
};
// "mav" stands for "multidimensional array view"
......@@ -148,6 +234,7 @@ template<typename T> class fmav: public fmav_info, public membuf<T>
{
protected:
using membuf<T>::d;
using membuf<T>::vraw;
void subdata(const shape_t &i0, const shape_t &extent,
shape_t &nshp, stride_t &nstr, ptrdiff_t &nofs) const
......@@ -174,6 +261,7 @@ template<typename T> class fmav: public fmav_info, public membuf<T>
}
public:
using membuf<T>::operator[], membuf<T>::vdata, membuf<T>::data;
fmav(const T *d_, const shape_t &shp_, const stride_t &str_)
: fmav_info(shp_, str_), membuf<T>(d_) {}
fmav(const T *d_, const shape_t &shp_)
......@@ -202,10 +290,15 @@ template<typename T> class fmav: public fmav_info, public membuf<T>
: fmav_info(shp_, str_), membuf<T>(buf) {}
fmav(const membuf<T> &buf, const shape_t &shp_, const stride_t &str_)
: fmav_info(shp_, str_), membuf<T>(buf) {}
fmav(const shape_t &shp_, const stride_t &str_, const T *d_, membuf<T> &mb)
: fmav_info(shp_, str_), membuf<T>(d_, mb) {}
fmav(const shape_t &shp_, const stride_t &str_, const T *d_, const membuf<T> &mb)
: fmav_info(shp_, str_), membuf<T>(d_, mb) {}
fmav(const shape_t &shp_, const stride_t &str_, const T *d_, membuf<T> &buf)
: fmav_info(shp_, str_), membuf<T>(d_, buf) {}
fmav(const shape_t &shp_, const stride_t &str_, const T *d_, const membuf<T> &buf)
: fmav_info(shp_, str_), membuf<T>(d_, buf) {}
template<typename... Ns> const T &operator()(Ns... ns) const
{ return operator[](idx(ns...)); }
template<typename... Ns> T &v(Ns... ns)
{ return vraw(idx(ns...)); }
fmav subarray(const shape_t &i0, const shape_t &extent)
{
......@@ -225,71 +318,6 @@ template<typename T> class fmav: public fmav_info, public membuf<T>
}
};
template<size_t ndim> class mav_info
{
protected:
static_assert(ndim>0, "at least 1D required");
using shape_t = array<size_t, ndim>;
using stride_t = array<ptrdiff_t, ndim>;
shape_t shp;
stride_t str;
size_t sz;
static size_t prod(const shape_t &shape)
{
size_t res=1;
for (auto sz: shape)
res*=sz;
return res;
}
static stride_t shape2stride(const shape_t &shp)
{
stride_t res;
res[ndim-1]=1;
for (size_t i=2; i<=ndim; ++i)
res[ndim-i] = res[ndim-i+1]*ptrdiff_t(shp[ndim-i+1]);
return res;
}
template<typename... Ns> ptrdiff_t getIdx(size_t dim, size_t n, Ns... ns) const
{ return str[dim]*n + getIdx(dim+1, ns...); }
ptrdiff_t getIdx(size_t dim, size_t n) const
{ return str[dim]*n; }
public:
mav_info(const shape_t &shape_, const stride_t &stride_)
: shp(shape_), str(stride_), sz(prod(shp)) {}
mav_info(const shape_t &shape_)
: shp(shape_), str(shape2stride(shp)), sz(prod(shp)) {}
size_t size() const { return sz; }
const shape_t &shape() const { return shp; }
size_t shape(size_t i) const { return shp[i]; }
const stride_t &stride() const { return str; }
const ptrdiff_t &stride(size_t i) const { return str[i]; }
bool last_contiguous() const
{ return (str.back()==1); }
bool contiguous() const
{
ptrdiff_t stride=1;
for (size_t i=0; i<ndim; ++i)
{
if (str[ndim-1-i]!=stride) return false;
stride *= ptrdiff_t(shp[ndim-1-i]);
}
return true;
}
bool conformable(const mav_info &other) const
{ return shp==other.shp; }
bool conformable(const shape_t &other) const
{ return shp==other; }
template<typename... Ns> ptrdiff_t idx(Ns... ns) const
{
static_assert(ndim==sizeof...(ns), "incorrect number of indices");
return getIdx(0, ns...);
}
};
template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membuf<T>
{
// static_assert((ndim>0) && (ndim<4), "only supports 1D, 2D, and 3D arrays");
......@@ -368,12 +396,8 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu
}
public:
using membuf<T>::operator[];
using membuf<T>::vdata;
using membuf<T>::data;
using mav_info<ndim>::contiguous;
using mav_info<ndim>::size;
using mav_info<ndim>::idx;
using membuf<T>::operator[], membuf<T>::vdata, membuf<T>::data;
using mav_info<ndim>::contiguous, mav_info<ndim>::size, mav_info<ndim>::idx;
using mav_info<ndim>::conformable;
mav(const T *d_, const shape_t &shp_, const stride_t &str_)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment