Commit 100e4215 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

restructure

parent 1da6d79c
......@@ -46,6 +46,8 @@ namespace py = pybind11;
namespace {
using shape_t = fmav_info::shape_t;
template<size_t nd1, size_t nd2> shape_t repl_dim(const shape_t &s,
const array<size_t,nd1> &si, const array<size_t,nd2> &so)
{
......
......@@ -21,8 +21,7 @@
namespace {
using mr::shape_t;
using mr::stride_t;
using shape_t = mr::fmav_info::shape_t;
using mr::fmav;
using mr::to_fmav;
using mr::get_optional_Pyarr;
......
......@@ -10,6 +10,9 @@ namespace mr {
namespace detail_pybind {
using shape_t=fmav_info::shape_t;
using stride_t=fmav_info::stride_t;
namespace py = pybind11;
template<typename T> bool isPyarr(const py::object &obj)
......
......@@ -34,11 +34,12 @@ namespace detail_mav {
using namespace std;
using shape_t = vector<size_t>;
using stride_t = vector<ptrdiff_t>;
class fmav_info
{
public:
using shape_t = vector<size_t>;
using stride_t = vector<ptrdiff_t>;
protected:
shape_t shp;
stride_t str;
......@@ -104,7 +105,7 @@ template<typename T> class membuf
membuf(const T *d_, const membuf &other)
: ptr(other.ptr), d(d_), rw(false) {}
public:
protected:
membuf(T *d_, bool rw_=false)
: d(d_), rw(rw_) {}
membuf(const T *d_)
......@@ -124,6 +125,7 @@ template<typename T> class membuf
membuf(membuf &&other) = default;
#endif
public:
template<typename I> T &vraw(I i)
{
MR_assert(rw, "array is not writable");
......@@ -144,6 +146,33 @@ template<typename T> class membuf
// "mav" stands for "multidimensional array view"
template<typename T> class fmav: public fmav_info, public membuf<T>
{
protected:
using membuf<T>::d;
void subdata(const shape_t &i0, const shape_t &extent,
shape_t &nshp, stride_t &nstr, ptrdiff_t &nofs) const
{
auto ndim = fmav_info::ndim();
MR_assert(i0.size()==ndim, "bad domensionality");
MR_assert(extent.size()==ndim, "bad domensionality");
size_t n0=0;
for (auto x:extent) if (x==0) ++n0;
nofs=0;
nshp.resize(ndim-n0);
nstr.resize(ndim-n0);
for (size_t i=0, i2=0; i<ndim; ++i)
{
MR_assert(i0[i]<shp[i], "bad subset");
nofs+=i0[i]*str[i];
if (extent[i]!=0)
{
MR_assert(i0[i]+extent[i2]<=shp[i], "bad subset");
nshp[i2] = extent[i]; nstr[i2]=str[i];
++i2;
}
}
}
public:
fmav(const T *d_, const shape_t &shp_, const stride_t &str_)
: fmav_info(shp_, str_), membuf<T>(d_) {}
......@@ -173,11 +202,34 @@ template<typename T> class fmav: public fmav_info, public membuf<T>
: fmav_info(shp_, str_), membuf<T>(buf) {}
fmav(const membuf<T> &buf, const shape_t &shp_, const stride_t &str_)
: fmav_info(shp_, str_), membuf<T>(buf) {}
fmav(const shape_t &shp_, const stride_t &str_, const T *d_, membuf<T> &mb)
: fmav_info(shp_, str_), membuf<T>(d_, mb) {}
fmav(const shape_t &shp_, const stride_t &str_, const T *d_, const membuf<T> &mb)
: fmav_info(shp_, str_), membuf<T>(d_, mb) {}
fmav subarray(const shape_t &i0, const shape_t &extent)
{
shape_t nshp;
stride_t nstr;
ptrdiff_t nofs;
subdata(i0, extent, nshp, nstr, nofs);
return fmav(nshp, nstr, d+nofs, *this);
}
fmav subarray(const shape_t &i0, const shape_t &extent) const
{
shape_t nshp;
stride_t nstr;
ptrdiff_t nofs;
subdata(i0, extent, nshp, nstr, nofs);
return fmav(nshp, nstr, d+nofs, *this);
}
};
template<size_t ndim> class mav_info
{
protected:
static_assert(ndim>0, "at least 1D required");
using shape_t = array<size_t, ndim>;
using stride_t = array<ptrdiff_t, ndim>;
......@@ -192,6 +244,14 @@ template<size_t ndim> class mav_info
res*=sz;
return res;
}
static stride_t shape2stride(const shape_t &shp)
{
stride_t res;
res[ndim-1]=1;
for (size_t i=2; i<=ndim; ++i)
res[ndim-i] = res[ndim-i+1]*ptrdiff_t(shp[ndim-i+1]);
return res;
}
template<typename... Ns> ptrdiff_t getIdx(size_t dim, size_t n, Ns... ns) const
{ return str[dim]*n + getIdx(dim+1, ns...); }
ptrdiff_t getIdx(size_t dim, size_t n) const
......@@ -199,16 +259,9 @@ template<size_t ndim> class mav_info
public:
mav_info(const shape_t &shape_, const stride_t &stride_)
: shp(shape_), str(stride_), sz(prod(shp))
{ MR_assert(shp.size()==str.size(), "dimensions mismatch"); }
: shp(shape_), str(stride_), sz(prod(shp)) {}
mav_info(const shape_t &shape_)
: shp(shape_), sz(prod(shp))
{
MR_assert(ndim>0, "at least 1D required");
str[ndim-1]=1;
for (size_t i=2; i<=ndim; ++i)
str[ndim-i] = str[ndim-i+1]*ptrdiff_t(shp[ndim-i+1]);
}
: shp(shape_), str(shape2stride(shp)), sz(prod(shp)) {}
size_t size() const { return sz; }
const shape_t &shape() const { return shp; }
size_t shape(size_t i) const { return shp[i]; }
......@@ -399,7 +452,7 @@ template<typename T, size_t ndim> class MavIter
fmav<T> mav;
array<size_t, ndim> shp;
array<ptrdiff_t, ndim> str;
shape_t pos;
fmav_info::shape_t pos;
ptrdiff_t idx_;
bool done_;
......@@ -451,8 +504,6 @@ template<typename T, size_t ndim> class MavIter
}
using detail_mav::shape_t;
using detail_mav::stride_t;
using detail_mav::fmav_info;
using detail_mav::fmav;
using detail_mav::mav;
......
......@@ -66,6 +66,8 @@ namespace mr {
namespace detail_fft {
using shape_t=fmav_info::shape_t;
constexpr bool FORWARD = true,
BACKWARD = false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment