Commit 485bb50d authored by Martin Reinecke
Browse files

template tweaks

parent 59d123e4
Pipeline #69743 failed with stages
in 1 minute and 18 seconds
......@@ -175,6 +175,10 @@ template<size_t ndim> class mav_info
res*=sz;
return res;
}
// Terminating overload of the index recursion: only one index `n` is left,
// and its contribution is str[dim]*n.
ptrdiff_t getIdx(size_t dim, size_t n) const
  {
  return str[dim]*n;
  }
// Recursive overload: adds str[dim]*n for the leading index `n` to the
// result obtained for the remaining indices `ns...`, which are matched
// against dim+1, dim+2, ...
template<typename... Ns> ptrdiff_t getIdx(size_t dim, size_t n, Ns... ns) const
  {
  return str[dim]*n + getIdx(dim+1, ns...);
  }
public:
mav_info(const shape_t &shape_, const stride_t &stride_)
......@@ -207,20 +211,10 @@ template<size_t ndim> class mav_info
}
// Reports whether `other` has the same shape (shp) as this object.
bool conformable(const mav_info &other) const
  {
  const bool same_shape = (shp==other.shp);
  return same_shape;
  }
ptrdiff_t idx(size_t i) const
template<typename... Ns> ptrdiff_t idx(Ns... ns) const
{
static_assert(ndim==1, "ndim must be 1");
return str[0]*i;
}
ptrdiff_t idx(size_t i, size_t j) const
{
static_assert(ndim==2, "ndim must be 2");
return str[0]*i + str[1]*j;
}
ptrdiff_t idx(size_t i, size_t j, size_t k) const
{
static_assert(ndim==3, "ndim must be 3");
return str[0]*i + str[1]*j + str[2]*k;
static_assert(ndim==sizeof...(ns), "incorrect number of indices");
return getIdx(0, ns...);
}
};
......@@ -242,6 +236,7 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu
using membuf<T>::operator[];
using membuf<T>::vdata;
using membuf<T>::data;
using mav_info<ndim>::contiguous;
using mav_info<ndim>::size;
using mav_info<ndim>::idx;
......@@ -266,21 +261,19 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu
{
return fmav<T>(*this, {shp.begin(), shp.end()}, {str.begin(), str.end()});
}
const T &operator()(size_t i) const
{ return operator[](idx(i)); }
const T &operator()(size_t i, size_t j) const
{ return operator[](idx(i,j)); }
const T &operator()(size_t i, size_t j, size_t k) const
{ return operator[](idx(i,j,k)); }
T &v(size_t i)
{ return vraw(idx(i)); }
T &v(size_t i, size_t j)
{ return vraw(idx(i,j)); }
T &v(size_t i, size_t j, size_t k)
{ return vraw(idx(i,j,k)); }
// Read-only element access: translates the indices `ns...` into a flat
// offset with idx() and returns the element operator[] yields for it.
template<typename... Ns> const T &operator()(Ns... ns) const
  {
  const auto offset = idx(ns...);
  return operator[](offset);
  }
// Writable element access: translates the indices `ns...` into a flat
// offset with idx() and returns the mutable reference vraw() yields for it.
template<typename... Ns> T &v(Ns... ns)
  {
  const auto offset = idx(ns...);
  return vraw(offset);
  }
template<typename Func> void apply(Func func)
{
T *d2 = vdata();
if (contiguous())
  {
  // Fast path: treat the data as a single run of size() elements starting
  // at d2 and call func on each of them exactly once. An empty array simply
  // skips the loop.
  // (Ranging over the braced list {d2, d2+size()} would visit only the two
  // pointers themselves, i.e. the first element and the invalid
  // one-past-the-end position.)
  for (T *cur=d2, *last=d2+size(); cur!=last; ++cur)
    func(*cur);
  return;
  }
// General (non-contiguous) case follows. FIXME: special case for zeroing?
if (ndim==1)
for (size_t i=0; i<shp[0]; ++i)
......@@ -309,6 +302,11 @@ template<typename T, size_t ndim> class MavIter
ptrdiff_t idx_;
bool done_;
// End of the recursion: a single remaining index `n` contributes str[dim]*n.
ptrdiff_t getIdx(size_t dim, size_t n) const
  {
  return str[dim]*n;
  }
// General step: contribution str[dim]*n of the first index plus whatever the
// remaining indices `ns...` contribute for the dimensions from dim+1 on.
template<typename... Ns> ptrdiff_t getIdx(size_t dim, size_t n, Ns... ns) const
  {
  return str[dim]*n + getIdx(dim+1, ns...);
  }
public:
MavIter(const fmav<T> &mav_)
: mav(mav_), pos(mav.ndim()-ndim,0), idx_(0), done_(false)
......@@ -342,33 +340,12 @@ template<typename T, size_t ndim> class MavIter
done_=true;
}
// Returns entry `i` of shp.
size_t shape(size_t i) const
  {
  return shp[i];
  }
ptrdiff_t idx(size_t i) const
{
static_assert(ndim==1, "ndim must be 1");
return idx_+i*str[0];
}
ptrdiff_t idx(size_t i, size_t j) const
{
static_assert(ndim==2, "ndim must be 2");
return idx_+i*str[0]+j*str[1];
}
ptrdiff_t idx(size_t i, size_t j, size_t k) const
{
static_assert(ndim==3, "ndim must be 3");
return idx_+i*str[0]+j*str[1]+k*str[2];
}
const T &operator()(size_t i) const
{ return mav[idx(i)]; }
const T &operator()(size_t i, size_t j) const
{ return mav[idx(i,j)]; }
const T &operator()(size_t i, size_t j, size_t k) const
{ return mav[idx(i,j,k)]; }
T &v(size_t i)
{ return mav.vraw(idx(i)); }
T &v(size_t i, size_t j)
{ return mav.vraw(idx(i,j)); }
T &v(size_t i, size_t j, size_t k)
{ return mav.vraw(idx(i,j,k)); }
template<typename... Ns> ptrdiff_t idx(Ns... ns) const
{ return idx_ + getIdx(0, ns...); }
// Read-only element access: computes the flat offset for `ns...` with idx()
// and returns the element mav[] yields for it.
template<typename... Ns> const T &operator()(Ns... ns) const
  {
  const auto offset = idx(ns...);
  return mav[offset];
  }
// Writable element access: computes the flat offset for `ns...` with idx()
// and returns the mutable reference mav.vraw() yields for it.
template<typename... Ns> T &v(Ns... ns)
  {
  const auto offset = idx(ns...);
  return mav.vraw(offset);
  }
};
}
......
......@@ -52,17 +52,17 @@ namespace mr {
namespace detail_simd {
template<typename T> constexpr bool vectorizable = false;
template<> constexpr bool vectorizable<float> = true;
template<> constexpr bool vectorizable<double> = true;
template<> constexpr bool vectorizable<int8_t> = true;
template<> constexpr bool vectorizable<uint8_t> = true;
template<> constexpr bool vectorizable<int16_t> = true;
template<> constexpr bool vectorizable<uint16_t> = true;
template<> constexpr bool vectorizable<int32_t> = true;
template<> constexpr bool vectorizable<uint32_t> = true;
template<> constexpr bool vectorizable<int64_t> = true;
template<> constexpr bool vectorizable<uint64_t> = true;
// Compile-time trait: true for the element types listed below, false for
// every other type.
// `static` may only appear on the primary template; an explicit
// specialization must not carry a storage-class specifier and takes its
// linkage from the primary template. (`const` is implied by `constexpr`.)
template<typename T> static constexpr bool vectorizable = false;
template<> constexpr bool vectorizable<float> = true;
template<> constexpr bool vectorizable<double> = true;
template<> constexpr bool vectorizable<int8_t> = true;
template<> constexpr bool vectorizable<uint8_t> = true;
template<> constexpr bool vectorizable<int16_t> = true;
template<> constexpr bool vectorizable<uint16_t> = true;
template<> constexpr bool vectorizable<int32_t> = true;
template<> constexpr bool vectorizable<uint32_t> = true;
template<> constexpr bool vectorizable<int64_t> = true;
template<> constexpr bool vectorizable<uint64_t> = true;
// Number of T values per group: reglen/sizeof(T) when T is vectorizable,
// otherwise 1. (Presumably reglen is a register width in bytes -- confirm
// against the callers.)
template<typename T, size_t reglen> constexpr size_t vlen
  = vectorizable<T> ? (reglen/sizeof(T)) : size_t(1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment