/*
 *  This file is part of the MR utility library.
 *
 *  This code is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This code is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this code; if not, write to the Free Software
 *  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
 */

/* Copyright (C) 2019-2020 Max-Planck-Society
   Author: Martin Reinecke */

/* NOTE(review): this copy of the header is damaged and will not compile.
   - Its line breaks were collapsed.
   - It looks as if text starting at many '<' characters and running to the
     next '>' was stripped. As a result:
     * "using shape_t = vector;" has lost its template argument;
     * the #include lines below have lost their header names;
     * every "template class X" has lost its parameter list;
     * several loop and function bodies are cut short.
   The code here is left token-for-token as found. Restore it from the
   upstream source before relying on it. */

#ifndef MRUTIL_MAV_H
#define MRUTIL_MAV_H

/* NOTE(review): the first four #include directives have lost their header
   names in this copy. */
#include
#include
#include
#include
#include "mr_util/infra/error_handling.h"

namespace mr {

namespace detail_mav {

using namespace std;

/* Run-time-rank shape (extent per dimension) and stride lists.
   NOTE(review): the template arguments are missing here. Judging by the
   accessors below (shape(i) returns size_t, stride(i) returns
   const ptrdiff_t &), presumably these are vectors of size_t and ptrdiff_t.
   Confirm against upstream. */
using shape_t = vector;
using stride_t = vector;

/* Shape/stride metadata for an array view whose number of dimensions is
   only known at run time. Strides are measured in elements: they are
   multiplied by element indices, and the shape-only constructor builds
   them from extents starting at 1. */
class fmav_info
  {
  protected:
    shape_t shp;   /* extent of each dimension */
    stride_t str;  /* stride of each dimension */
    size_t sz;     /* total element count, i.e. prod(shp) */

    /* Product of all extents (1 for an empty shape). */
    static size_t prod(const shape_t &shape)
      {
      size_t res=1;
      for (auto sz: shape) res*=sz;
      return res;
      }

  public:
    /* Takes shape and strides as given. Asserts that there is at least one
       dimension and that both lists have the same length. */
    fmav_info(const shape_t &shape_, const stride_t &stride_)
      : shp(shape_), str(stride_), sz(prod(shp))
      {
      MR_assert(shp.size()>0, "at least 1D required");
      MR_assert(shp.size()==str.size(), "dimensions mismatch");
      }
    /* Takes a shape and computes C-order (row-major, last index fastest)
       contiguous strides for it. */
    fmav_info(const shape_t &shape_)
      : shp(shape_), str(shape_.size()), sz(prod(shp))
      {
      auto ndim = shp.size();
      MR_assert(ndim>0, "at least 1D required");
      str[ndim-1]=1;
      /* each stride is the next (faster-varying) stride times that
         dimension's extent */
      for (size_t i=2; i<=ndim; ++i)
        str[ndim-i] = str[ndim-i+1]*ptrdiff_t(shp[ndim-i+1]);
      }
    size_t ndim() const { return shp.size(); }
    size_t size() const { return sz; }
    const shape_t &shape() const { return shp; }
    size_t shape(size_t i) const { return shp[i]; }
    const stride_t &stride() const { return str; }
    const ptrdiff_t &stride(size_t i) const { return str[i]; }
    /* True if the last dimension has unit stride. */
    bool last_contiguous() const
      { return (str.back()==1); }
    /* Presumably checks for C-order contiguity: it starts from an expected
       stride of 1 and walks the dimensions. */
    bool contiguous() const
      {
      auto ndim = shp.size();
      ptrdiff_t stride=1;
      for (size_t i=0; i
/* NOTE(review): the remainder of contiguous(), the rest of class fmav_info
   and the template head of membuf are missing at this point. The text jumps
   straight from the loop condition above into "class membuf". */
class membuf
  {
  protected:
    /* Owning handle type. NOTE(review): its template argument is missing
       here. membuf(size_t) fills it from make_unique...(sz) and reads
       ptr->data(). */
    using Tsp = shared_ptr>;
    Tsp ptr;     /* owns storage allocated by membuf(size_t) or shared from
                    another membuf; left empty by the external-memory
                    constructors */
    const T *d;  /* pointer to the first element */
    bool rw;     /* whether vraw()/vdata() may hand out writable access */

    /* View other's storage starting at d_; keeps other's writability. */
    membuf(const T *d_, membuf &other)
      : ptr(other.ptr), d(d_), rw(other.rw) {}
    /* Same, but from a const buffer: the result is always read-only. */
    membuf(const T *d_, const membuf &other)
      : ptr(other.ptr), d(d_), rw(false) {}

  public:
    /* Non-owning view of external memory; writable only if rw_ is true. */
    membuf(T *d_, bool rw_=false)
      : d(d_), rw(rw_) {}
    /* Non-owning, read-only view of external memory. */
    membuf(const T *d_)
      : d(d_), rw(false) {}
    /* Allocates owned storage for sz elements and marks it writable. */
    membuf(size_t sz)
      : ptr(make_unique>(sz)), d(ptr->data()), rw(true) {}
    /* Copying a const buffer shares its storage but yields a read-only
       buffer. Copying a non-const buffer (the defaulted overload) copies
       all members, including rw. */
    membuf(const membuf &other)
      : ptr(other.ptr), d(other.d), rw(false) {}
    membuf(membuf &other) = default;
    membuf(membuf &&other) = default;

    /* Writable element access; asserts that the buffer is writable.
       d is stored as const T*, so the cast relies on rw having been set
       only for memory that was handed in (or allocated) as writable. */
    template T &vraw(I i)
      {
      MR_assert(rw, "array is not writable");
      return const_cast(d)[i];
      }
    /* Read-only element access by raw offset (no bounds check). */
    template const T &operator[](I i) const
      { return d[i]; }
    const T *data() const
      { return d; }
    /* Writable pointer to the first element; asserts writability. */
    T *vdata()
      {
      MR_assert(rw, "array is not writable");
      return const_cast(d);
      }
    bool writable() { return rw; }
  };

/* NOTE(review): because the line breaks were collapsed, the "//" comment on
   the next line now runs to the end of that line. It therefore comments out
   the whole declaration of class fmav (a run-time-rank view deriving from
   fmav_info and membuf) and the head of class mav_info. In the intact header
   this comment stands on its own line. */
// "mav" stands for "multidimensional array view" template class fmav: public fmav_info, public membuf { public: fmav(const T *d_, const shape_t &shp_, const stride_t &str_) : fmav_info(shp_, str_), membuf(d_) {} fmav(const T *d_, const shape_t &shp_) : fmav_info(shp_), membuf(d_) {} fmav(T *d_, const shape_t &shp_, const stride_t &str_, bool rw_) : fmav_info(shp_, str_), membuf(d_,rw_) {} fmav(T *d_, const shape_t &shp_, bool rw_) : fmav_info(shp_), membuf(d_,rw_) {} fmav(const shape_t &shp_) : fmav_info(shp_), membuf(size()) {} fmav(T* d_, const fmav_info &info, bool rw_=false) : fmav_info(info), membuf(d_, rw_) {} fmav(const T* d_, const fmav_info &info) : fmav_info(info), membuf(d_) {} fmav(const fmav &other) = default; fmav(fmav &other) = default; fmav(fmav &&other) = default; fmav(membuf &buf, const shape_t &shp_, const stride_t &str_) : fmav_info(shp_, str_), membuf(buf) {} fmav(const membuf &buf, const shape_t &shp_, const stride_t &str_) : fmav_info(shp_, str_), membuf(buf) {} }; template class mav_info { protected: 
    /* Members of mav_info: the compile-time-rank counterpart of fmav_info.
       The rank is the constant ndim, whose declaration (presumably a
       template parameter) is lost in this copy. Shape and strides are
       fixed-size arrays. NOTE(review): their template arguments are
       missing here. */
    using shape_t = array;
    using stride_t = array;
    shape_t shp;
    stride_t str;
    size_t sz;

    /* Product of all extents. */
    static size_t prod(const shape_t &shape)
      {
      size_t res=1;
      for (auto sz: shape) res*=sz;
      return res;
      }

    /* Element offset of multi-index (n, ns...) starting at dimension dim:
       the sum of index times stride over the remaining dimensions. */
    template ptrdiff_t getIdx(size_t dim, size_t n, Ns... ns) const
      { return str[dim]*n + getIdx(dim+1, ns...); }
    ptrdiff_t getIdx(size_t dim, size_t n) const
      { return str[dim]*n; }

  public:
    mav_info(const shape_t &shape_, const stride_t &stride_)
      : shp(shape_), str(stride_), sz(prod(shp))
      {
      MR_assert(shp.size()==str.size(), "dimensions mismatch");
      }
    /* Computes C-order contiguous strides (last stride 1), as fmav_info's
       shape-only constructor does. */
    mav_info(const shape_t &shape_)
      : shp(shape_), sz(prod(shp))
      {
      MR_assert(ndim>0, "at least 1D required");
      str[ndim-1]=1;
      for (size_t i=2; i<=ndim; ++i)
        str[ndim-i] = str[ndim-i+1]*ptrdiff_t(shp[ndim-i+1]);
      }
    size_t size() const { return sz; }
    const shape_t &shape() const { return shp; }
    size_t shape(size_t i) const { return shp[i]; }
    const stride_t &stride() const { return str; }
    const ptrdiff_t &stride(size_t i) const { return str[i]; }
    bool last_contiguous() const
      { return (str.back()==1); }
    bool contiguous() const
      {
      ptrdiff_t stride=1;
      for (size_t i=0; i
    /* NOTE(review): the rest of contiguous() and any members between it and
       idx() are missing. The text resumes in the middle of idx()'s template
       declaration. */
    /* Element offset of the multi-index ns. The static_assert requires
       exactly ndim indices. */
    ptrdiff_t idx(Ns... ns) const
      {
      static_assert(ndim==sizeof...(ns), "incorrect number of indices");
      return getIdx(0, ns...);
      }
  };

/* Compile-time-rank view: mav_info metadata plus a membuf. */
template class mav: public mav_info, public membuf {
/* NOTE(review): in the intact header the "//" line below only disables a
   static_assert restricting ndim to 1..3. Because the line breaks were
   collapsed, it now also comments out most of the mav class body: the
   using-declarations, the applyHelper overloads, subdata, and all but the
   last constructors, up to the mem-initializer list that continues after
   it. */
// static_assert((ndim>0) && (ndim<4), "only supports 1D, 2D, and 3D arrays"); protected: using typename mav_info::shape_t; using typename mav_info::stride_t; using membuf::d; using membuf::ptr; using mav_info::shp; using mav_info::str; using membuf::rw; using membuf::vraw; template void applyHelper(ptrdiff_t idx, Func func) { if constexpr (idim+1(idx+i*str[idim], func); else { T *d2 = vdata(); for (size_t i=0; i void applyHelper(ptrdiff_t idx, Func func) const { if constexpr (idim+1(idx+i*str[idim], func); else { const T *d2 = data(); for (size_t i=0; i void applyHelper(ptrdiff_t idx, ptrdiff_t idx2, const mav &other, Func func) { if constexpr (idim==0) MR_assert(conformable(other), "dimension mismatch"); if constexpr (idim+1(idx+i*str[idim], idx2+i*other.str[idim], other, func); else { T *d2 = vdata(); const T2 *d3 = other.data(); for (size_t i=0; i void subdata(const shape_t &i0, const shape_t &extent, array &nshp, array &nstr, ptrdiff_t &nofs) const { size_t n0=0; for (auto x:extent) if (x==0)++n0; MR_assert(n0+nd2==ndim, "bad extent"); nofs=0; for (size_t i=0, i2=0; i::operator[]; using membuf::vdata; using membuf::data; using mav_info::contiguous; using mav_info::size; using mav_info::idx; using mav_info::conformable; mav(const T *d_, const shape_t &shp_, const stride_t &str_) : mav_info(shp_, str_), membuf(d_) {} mav(T *d_, const shape_t &shp_, const stride_t &str_, bool rw_=false) : mav_info(shp_, str_), membuf(d_, rw_) {} mav(const T *d_, const shape_t &shp_) : mav_info(shp_), membuf(d_) {} mav(T *d_, const shape_t &shp_, bool rw_=false) : mav_info(shp_), membuf(d_, rw_) {} mav(const array &shp_) : mav_info(shp_), membuf(size()) {} mav(const mav &other) = default; mav(mav &other) = default; mav(mav &&other) = default; mav(const shape_t &shp_, const stride_t &str_, const T *d_, membuf &mb) : 
      /* tail of the (shape, strides, pointer, membuf &) constructor whose
         head sits in the comment above */
      mav_info(shp_, str_), membuf(d_, mb) {}
    /* Same from a const membuf: goes through membuf's const-source
       constructor, so the result is read-only. */
    mav(const shape_t &shp_, const stride_t &str_, const T *d_, const membuf &mb)
      : mav_info(shp_, str_), membuf(d_, mb) {}

    /* Conversion to a run-time-rank fmav sharing this buffer. Shape and
       strides are copied into the fmav's lists. The const overload yields a
       read-only view (membuf's const copy constructor clears rw). */
    operator fmav() const
      { return fmav(*this, {shp.begin(), shp.end()}, {str.begin(), str.end()}); }
    operator fmav()
      { return fmav(*this, {shp.begin(), shp.end()}, {str.begin(), str.end()}); }

    /* Read-only element access by multi-index. */
    template const T &operator()(Ns... ns) const
      { return operator[](idx(ns...)); }
    /* Writable element access by multi-index; vraw asserts writability. */
    template T &v(Ns... ns)
      { return vraw(idx(ns...)); }

    /* Calls func on every element through a writable reference.
       Contiguous data is walked linearly from vdata(); otherwise the
       recursive applyHelper is used. */
    template void apply(Func func)
      {
      if (contiguous())
        {
        T *d2 = vdata();
        for (auto v=d2; v!=d2+size(); ++v)
          func(*v);
        return;
        }
      applyHelper<0,Func>(0,func);
      }
    /* Calls func on corresponding elements of *this and other. The
       applyHelper overload text above asserts conformable shapes. */
    template void apply (const mav &other,Func func)
      { applyHelper<0,T2,Func>(0,0,other,func); }

    /* Sets every element to val (requires a writable view). */
    void fill(const T &val)
      { apply([val](T &v){v=val;}); }

    /* Returns a view of the sub-block starting at i0 that shares this
       buffer. From the subdata text above it appears that dimensions with
       extent 0 are dropped from the result. TODO(review): confirm against
       upstream. The non-const overload passes a non-const *this and so
       keeps writability; the const overload yields a read-only view. */
    template mav subarray(const shape_t &i0, const shape_t &extent)
      {
      array nshp;
      array nstr;
      ptrdiff_t nofs;
      subdata (i0, extent, nshp, nstr, nofs);
      return mav (nshp, nstr, d+nofs, *this);
      }
    template mav subarray(const shape_t &i0, const shape_t &extent) const
      {
      array nshp;
      array nstr;
      ptrdiff_t nofs;
      subdata (i0, extent, nshp, nstr, nofs);
      return mav (nshp, nstr, d+nofs, *this);
      }
  };

/* Iterator over an fmav. pos has mav.ndim()-ndim entries, so it appears to
   step through the leading dimensions. At each position, idx()/operator()/v()
   address the trailing ndim dimensions relative to the running offset idx_.
   TODO(review): confirm against upstream; the constructor and increment
   bodies are incomplete in this copy. */
template class MavIter
  {
  protected:
    fmav mav;
    array shp;
    array str;
    shape_t pos;     /* current position in the iterated dimensions */
    ptrdiff_t idx_;  /* element offset of the current position */
    bool done_;

    /* Offset contribution of the inner multi-index (n, ns...) from
       dimension dim on. */
    template ptrdiff_t getIdx(size_t dim, size_t n, Ns... ns) const
      { return str[dim]*n + getIdx(dim+1, ns...); }
    ptrdiff_t getIdx(size_t dim, size_t n) const
      { return str[dim]*n; }

  public:
    /* NOTE(review): mav.ndim()-ndim is computed in size_t arithmetic in the
       mem-initializer. If the fmav has fewer than ndim dimensions it wraps
       around before any check in the constructor body could run. */
    MavIter(const fmav &mav_)
      : mav(mav_), pos(mav.ndim()-ndim,0), idx_(0), done_(false)
      {
      for (size_t i=0; i
      /* NOTE(review): the rest of this constructor body and the head of a
         second constructor are missing here. */
      &mav_)
      : mav(mav_), pos(mav.ndim()-ndim,0), idx_(0), done_(false)
      {
      for (size_t i=0; i
      /* NOTE(review): the rest of this constructor and the head of the
         increment routine are missing. What follows is the body of a
         reverse loop over the iterated dimensions that advances idx_ by
         each dimension's stride. */
      =0; --i)
        {
        idx_+=mav.stride(i);
        if (++pos[i]
    /* NOTE(review): the rest of the increment routine, and any members up to
       the template head of idx(), are missing. */
    /* Offset of inner multi-index ns relative to the current position. */
    ptrdiff_t idx(Ns... ns) const
      { return idx_ + getIdx(0, ns...); }
    /* Read-only element access at the current position. */
    template const T &operator()(Ns... ns) const
      { return mav[idx(ns...)]; }
    /* Writable element access at the current position (asserts
       writability via vraw). */
    template T &v(Ns... ns)
      { return mav.vraw(idx(ns...)); }
  };

}

/* Public re-exports into namespace mr. */
using detail_mav::shape_t;
using detail_mav::stride_t;
using detail_mav::fmav_info;
using detail_mav::fmav;
using detail_mav::mav;
using detail_mav::MavIter;

}

#endif