Commit 485bb50d authored by Martin Reinecke's avatar Martin Reinecke
Browse files

template tweaks

parent 59d123e4
Pipeline #69743 failed with stages
in 1 minute and 18 seconds
...@@ -175,6 +175,10 @@ template<size_t ndim> class mav_info ...@@ -175,6 +175,10 @@ template<size_t ndim> class mav_info
res*=sz; res*=sz;
return res; return res;
} }
template<typename... Ns> ptrdiff_t getIdx(size_t dim, size_t n, Ns... ns) const
{ return str[dim]*n + getIdx(dim+1, ns...); }
ptrdiff_t getIdx(size_t dim, size_t n) const
{ return str[dim]*n; }
public: public:
mav_info(const shape_t &shape_, const stride_t &stride_) mav_info(const shape_t &shape_, const stride_t &stride_)
...@@ -207,20 +211,10 @@ template<size_t ndim> class mav_info ...@@ -207,20 +211,10 @@ template<size_t ndim> class mav_info
} }
bool conformable(const mav_info &other) const bool conformable(const mav_info &other) const
{ return shp==other.shp; } { return shp==other.shp; }
ptrdiff_t idx(size_t i) const template<typename... Ns> ptrdiff_t idx(Ns... ns) const
{ {
static_assert(ndim==1, "ndim must be 1"); static_assert(ndim==sizeof...(ns), "incorrect number of indices");
return str[0]*i; return getIdx(0, ns...);
}
ptrdiff_t idx(size_t i, size_t j) const
{
static_assert(ndim==2, "ndim must be 2");
return str[0]*i + str[1]*j;
}
ptrdiff_t idx(size_t i, size_t j, size_t k) const
{
static_assert(ndim==3, "ndim must be 3");
return str[0]*i + str[1]*j + str[2]*k;
} }
}; };
...@@ -242,6 +236,7 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu ...@@ -242,6 +236,7 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu
using membuf<T>::operator[]; using membuf<T>::operator[];
using membuf<T>::vdata; using membuf<T>::vdata;
using membuf<T>::data; using membuf<T>::data;
using mav_info<ndim>::contiguous;
using mav_info<ndim>::size; using mav_info<ndim>::size;
using mav_info<ndim>::idx; using mav_info<ndim>::idx;
...@@ -266,21 +261,19 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu ...@@ -266,21 +261,19 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu
{ {
return fmav<T>(*this, {shp.begin(), shp.end()}, {str.begin(), str.end()}); return fmav<T>(*this, {shp.begin(), shp.end()}, {str.begin(), str.end()});
} }
const T &operator()(size_t i) const template<typename... Ns> const T &operator()(Ns... ns) const
{ return operator[](idx(i)); } { return operator[](idx(ns...)); }
const T &operator()(size_t i, size_t j) const template<typename... Ns> T &v(Ns... ns)
{ return operator[](idx(i,j)); } { return vraw(idx(ns...)); }
const T &operator()(size_t i, size_t j, size_t k) const
{ return operator[](idx(i,j,k)); }
T &v(size_t i)
{ return vraw(idx(i)); }
T &v(size_t i, size_t j)
{ return vraw(idx(i,j)); }
T &v(size_t i, size_t j, size_t k)
{ return vraw(idx(i,j,k)); }
template<typename Func> void apply(Func func) template<typename Func> void apply(Func func)
{ {
T *d2 = vdata(); T *d2 = vdata();
if (contiguous())
{
for (auto &v: {d2, d2+size()})
func(*v);
return;
}
// FIXME: special cases for contiguous arrays and/or zeroing? // FIXME: special cases for contiguous arrays and/or zeroing?
if (ndim==1) if (ndim==1)
for (size_t i=0; i<shp[0]; ++i) for (size_t i=0; i<shp[0]; ++i)
...@@ -309,6 +302,11 @@ template<typename T, size_t ndim> class MavIter ...@@ -309,6 +302,11 @@ template<typename T, size_t ndim> class MavIter
ptrdiff_t idx_; ptrdiff_t idx_;
bool done_; bool done_;
template<typename... Ns> ptrdiff_t getIdx(size_t dim, size_t n, Ns... ns) const
{ return str[dim]*n + getIdx(dim+1, ns...); }
ptrdiff_t getIdx(size_t dim, size_t n) const
{ return str[dim]*n; }
public: public:
MavIter(const fmav<T> &mav_) MavIter(const fmav<T> &mav_)
: mav(mav_), pos(mav.ndim()-ndim,0), idx_(0), done_(false) : mav(mav_), pos(mav.ndim()-ndim,0), idx_(0), done_(false)
...@@ -342,33 +340,12 @@ template<typename T, size_t ndim> class MavIter ...@@ -342,33 +340,12 @@ template<typename T, size_t ndim> class MavIter
done_=true; done_=true;
} }
size_t shape(size_t i) const { return shp[i]; } size_t shape(size_t i) const { return shp[i]; }
ptrdiff_t idx(size_t i) const template<typename... Ns> ptrdiff_t idx(Ns... ns) const
{ { return idx_ + getIdx(0, ns...); }
static_assert(ndim==1, "ndim must be 1"); template<typename... Ns> const T &operator()(Ns... ns) const
return idx_+i*str[0]; { return mav[idx(ns...)]; }
} template<typename... Ns> T &v(Ns... ns)
ptrdiff_t idx(size_t i, size_t j) const { return mav.vraw(idx(ns...)); }
{
static_assert(ndim==2, "ndim must be 2");
return idx_+i*str[0]+j*str[1];
}
ptrdiff_t idx(size_t i, size_t j, size_t k) const
{
static_assert(ndim==3, "ndim must be 3");
return idx_+i*str[0]+j*str[1]+k*str[2];
}
const T &operator()(size_t i) const
{ return mav[idx(i)]; }
const T &operator()(size_t i, size_t j) const
{ return mav[idx(i,j)]; }
const T &operator()(size_t i, size_t j, size_t k) const
{ return mav[idx(i,j,k)]; }
T &v(size_t i)
{ return mav.vraw(idx(i)); }
T &v(size_t i, size_t j)
{ return mav.vraw(idx(i,j)); }
T &v(size_t i, size_t j, size_t k)
{ return mav.vraw(idx(i,j,k)); }
}; };
} }
......
...@@ -52,17 +52,17 @@ namespace mr { ...@@ -52,17 +52,17 @@ namespace mr {
namespace detail_simd { namespace detail_simd {
template<typename T> constexpr bool vectorizable = false; template<typename T> constexpr static const bool vectorizable = false;
template<> constexpr bool vectorizable<float> = true; template<> constexpr static const bool vectorizable<float> = true;
template<> constexpr bool vectorizable<double> = true; template<> constexpr static const bool vectorizable<double> = true;
template<> constexpr bool vectorizable<int8_t> = true; template<> constexpr static const bool vectorizable<int8_t> = true;
template<> constexpr bool vectorizable<uint8_t> = true; template<> constexpr static const bool vectorizable<uint8_t> = true;
template<> constexpr bool vectorizable<int16_t> = true; template<> constexpr static const bool vectorizable<int16_t> = true;
template<> constexpr bool vectorizable<uint16_t> = true; template<> constexpr static const bool vectorizable<uint16_t> = true;
template<> constexpr bool vectorizable<int32_t> = true; template<> constexpr static const bool vectorizable<int32_t> = true;
template<> constexpr bool vectorizable<uint32_t> = true; template<> constexpr static const bool vectorizable<uint32_t> = true;
template<> constexpr bool vectorizable<int64_t> = true; template<> constexpr static const bool vectorizable<int64_t> = true;
template<> constexpr bool vectorizable<uint64_t> = true; template<> constexpr static const bool vectorizable<uint64_t> = true;
template<typename T, size_t reglen> constexpr size_t vlen template<typename T, size_t reglen> constexpr size_t vlen
= vectorizable<T> ? reglen/sizeof(T) : 1; = vectorizable<T> ? reglen/sizeof(T) : 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment