Commit 100e4215 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

restructure

parent 1da6d79c
Pipeline #74276 passed with stages
in 12 minutes and 10 seconds
...@@ -46,6 +46,8 @@ namespace py = pybind11; ...@@ -46,6 +46,8 @@ namespace py = pybind11;
namespace { namespace {
using shape_t = fmav_info::shape_t;
template<size_t nd1, size_t nd2> shape_t repl_dim(const shape_t &s, template<size_t nd1, size_t nd2> shape_t repl_dim(const shape_t &s,
const array<size_t,nd1> &si, const array<size_t,nd2> &so) const array<size_t,nd1> &si, const array<size_t,nd2> &so)
{ {
......
...@@ -21,8 +21,7 @@ ...@@ -21,8 +21,7 @@
namespace { namespace {
using mr::shape_t; using shape_t = mr::fmav_info::shape_t;
using mr::stride_t;
using mr::fmav; using mr::fmav;
using mr::to_fmav; using mr::to_fmav;
using mr::get_optional_Pyarr; using mr::get_optional_Pyarr;
......
...@@ -10,6 +10,9 @@ namespace mr { ...@@ -10,6 +10,9 @@ namespace mr {
namespace detail_pybind { namespace detail_pybind {
using shape_t=fmav_info::shape_t;
using stride_t=fmav_info::stride_t;
namespace py = pybind11; namespace py = pybind11;
template<typename T> bool isPyarr(const py::object &obj) template<typename T> bool isPyarr(const py::object &obj)
......
...@@ -34,11 +34,12 @@ namespace detail_mav { ...@@ -34,11 +34,12 @@ namespace detail_mav {
using namespace std; using namespace std;
using shape_t = vector<size_t>;
using stride_t = vector<ptrdiff_t>;
class fmav_info class fmav_info
{ {
public:
using shape_t = vector<size_t>;
using stride_t = vector<ptrdiff_t>;
protected: protected:
shape_t shp; shape_t shp;
stride_t str; stride_t str;
...@@ -104,7 +105,7 @@ template<typename T> class membuf ...@@ -104,7 +105,7 @@ template<typename T> class membuf
membuf(const T *d_, const membuf &other) membuf(const T *d_, const membuf &other)
: ptr(other.ptr), d(d_), rw(false) {} : ptr(other.ptr), d(d_), rw(false) {}
public: protected:
membuf(T *d_, bool rw_=false) membuf(T *d_, bool rw_=false)
: d(d_), rw(rw_) {} : d(d_), rw(rw_) {}
membuf(const T *d_) membuf(const T *d_)
...@@ -124,6 +125,7 @@ template<typename T> class membuf ...@@ -124,6 +125,7 @@ template<typename T> class membuf
membuf(membuf &&other) = default; membuf(membuf &&other) = default;
#endif #endif
public:
template<typename I> T &vraw(I i) template<typename I> T &vraw(I i)
{ {
MR_assert(rw, "array is not writable"); MR_assert(rw, "array is not writable");
...@@ -144,6 +146,33 @@ template<typename T> class membuf ...@@ -144,6 +146,33 @@ template<typename T> class membuf
// "mav" stands for "multidimensional array view" // "mav" stands for "multidimensional array view"
template<typename T> class fmav: public fmav_info, public membuf<T> template<typename T> class fmav: public fmav_info, public membuf<T>
{ {
protected:
using membuf<T>::d;
void subdata(const shape_t &i0, const shape_t &extent,
shape_t &nshp, stride_t &nstr, ptrdiff_t &nofs) const
{
auto ndim = fmav_info::ndim();
MR_assert(i0.size()==ndim, "bad domensionality");
MR_assert(extent.size()==ndim, "bad domensionality");
size_t n0=0;
for (auto x:extent) if (x==0) ++n0;
nofs=0;
nshp.resize(ndim-n0);
nstr.resize(ndim-n0);
for (size_t i=0, i2=0; i<ndim; ++i)
{
MR_assert(i0[i]<shp[i], "bad subset");
nofs+=i0[i]*str[i];
if (extent[i]!=0)
{
MR_assert(i0[i]+extent[i2]<=shp[i], "bad subset");
nshp[i2] = extent[i]; nstr[i2]=str[i];
++i2;
}
}
}
public: public:
fmav(const T *d_, const shape_t &shp_, const stride_t &str_) fmav(const T *d_, const shape_t &shp_, const stride_t &str_)
: fmav_info(shp_, str_), membuf<T>(d_) {} : fmav_info(shp_, str_), membuf<T>(d_) {}
...@@ -173,11 +202,34 @@ template<typename T> class fmav: public fmav_info, public membuf<T> ...@@ -173,11 +202,34 @@ template<typename T> class fmav: public fmav_info, public membuf<T>
: fmav_info(shp_, str_), membuf<T>(buf) {} : fmav_info(shp_, str_), membuf<T>(buf) {}
fmav(const membuf<T> &buf, const shape_t &shp_, const stride_t &str_) fmav(const membuf<T> &buf, const shape_t &shp_, const stride_t &str_)
: fmav_info(shp_, str_), membuf<T>(buf) {} : fmav_info(shp_, str_), membuf<T>(buf) {}
fmav(const shape_t &shp_, const stride_t &str_, const T *d_, membuf<T> &mb)
: fmav_info(shp_, str_), membuf<T>(d_, mb) {}
fmav(const shape_t &shp_, const stride_t &str_, const T *d_, const membuf<T> &mb)
: fmav_info(shp_, str_), membuf<T>(d_, mb) {}
fmav subarray(const shape_t &i0, const shape_t &extent)
{
shape_t nshp;
stride_t nstr;
ptrdiff_t nofs;
subdata(i0, extent, nshp, nstr, nofs);
return fmav(nshp, nstr, d+nofs, *this);
}
fmav subarray(const shape_t &i0, const shape_t &extent) const
{
shape_t nshp;
stride_t nstr;
ptrdiff_t nofs;
subdata(i0, extent, nshp, nstr, nofs);
return fmav(nshp, nstr, d+nofs, *this);
}
}; };
template<size_t ndim> class mav_info template<size_t ndim> class mav_info
{ {
protected: protected:
static_assert(ndim>0, "at least 1D required");
using shape_t = array<size_t, ndim>; using shape_t = array<size_t, ndim>;
using stride_t = array<ptrdiff_t, ndim>; using stride_t = array<ptrdiff_t, ndim>;
...@@ -192,6 +244,14 @@ template<size_t ndim> class mav_info ...@@ -192,6 +244,14 @@ template<size_t ndim> class mav_info
res*=sz; res*=sz;
return res; return res;
} }
static stride_t shape2stride(const shape_t &shp)
{
stride_t res;
res[ndim-1]=1;
for (size_t i=2; i<=ndim; ++i)
res[ndim-i] = res[ndim-i+1]*ptrdiff_t(shp[ndim-i+1]);
return res;
}
template<typename... Ns> ptrdiff_t getIdx(size_t dim, size_t n, Ns... ns) const template<typename... Ns> ptrdiff_t getIdx(size_t dim, size_t n, Ns... ns) const
{ return str[dim]*n + getIdx(dim+1, ns...); } { return str[dim]*n + getIdx(dim+1, ns...); }
ptrdiff_t getIdx(size_t dim, size_t n) const ptrdiff_t getIdx(size_t dim, size_t n) const
...@@ -199,16 +259,9 @@ template<size_t ndim> class mav_info ...@@ -199,16 +259,9 @@ template<size_t ndim> class mav_info
public: public:
mav_info(const shape_t &shape_, const stride_t &stride_) mav_info(const shape_t &shape_, const stride_t &stride_)
: shp(shape_), str(stride_), sz(prod(shp)) : shp(shape_), str(stride_), sz(prod(shp)) {}
{ MR_assert(shp.size()==str.size(), "dimensions mismatch"); }
mav_info(const shape_t &shape_) mav_info(const shape_t &shape_)
: shp(shape_), sz(prod(shp)) : shp(shape_), str(shape2stride(shp)), sz(prod(shp)) {}
{
MR_assert(ndim>0, "at least 1D required");
str[ndim-1]=1;
for (size_t i=2; i<=ndim; ++i)
str[ndim-i] = str[ndim-i+1]*ptrdiff_t(shp[ndim-i+1]);
}
size_t size() const { return sz; } size_t size() const { return sz; }
const shape_t &shape() const { return shp; } const shape_t &shape() const { return shp; }
size_t shape(size_t i) const { return shp[i]; } size_t shape(size_t i) const { return shp[i]; }
...@@ -399,7 +452,7 @@ template<typename T, size_t ndim> class MavIter ...@@ -399,7 +452,7 @@ template<typename T, size_t ndim> class MavIter
fmav<T> mav; fmav<T> mav;
array<size_t, ndim> shp; array<size_t, ndim> shp;
array<ptrdiff_t, ndim> str; array<ptrdiff_t, ndim> str;
shape_t pos; fmav_info::shape_t pos;
ptrdiff_t idx_; ptrdiff_t idx_;
bool done_; bool done_;
...@@ -451,8 +504,6 @@ template<typename T, size_t ndim> class MavIter ...@@ -451,8 +504,6 @@ template<typename T, size_t ndim> class MavIter
} }
using detail_mav::shape_t;
using detail_mav::stride_t;
using detail_mav::fmav_info; using detail_mav::fmav_info;
using detail_mav::fmav; using detail_mav::fmav;
using detail_mav::mav; using detail_mav::mav;
......
...@@ -66,6 +66,8 @@ namespace mr { ...@@ -66,6 +66,8 @@ namespace mr {
namespace detail_fft { namespace detail_fft {
using shape_t=fmav_info::shape_t;
constexpr bool FORWARD = true, constexpr bool FORWARD = true,
BACKWARD = false; BACKWARD = false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment