Commit e2b8a7ca authored by Martin Reinecke

start using fmav in FFT

parent 00c8d288
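
The user-visible effect of this change is that the stride arguments of the public FFT entry points (c2c, r2c, c2r, r2r_fftpack, ...) are now counted in array elements instead of bytes, as the sharp_geomhelpers hunks below show ({sizeof(double)} becomes {1}). A minimal calling-convention sketch; the wrapper function name and the commented include path are illustrative assumptions, the r2r_fftpack call itself mirrors the diff:

#include <cstddef>
#include <vector>
// #include "mr_util/fft.h"   // assumed path of the header providing mr::r2r_fftpack

// Hypothetical helper: in-place FFTPACK-style transform of a 1-D weight array.
void r2r_stride_sketch(size_t n, std::vector<double> &weight)
  {
  // Before this commit, strides were given in bytes:
  //   mr::r2r_fftpack({n}, {sizeof(double)}, {sizeof(double)}, {0},
  //                   false, false, weight.data(), weight.data(), 1.);
  // After this commit, strides are given in elements:
  mr::r2r_fftpack({n}, {1}, {1}, {0},
                  false, false, weight.data(), weight.data(), 1.);
  }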
@@ -87,3 +87,4 @@ libtool
ar-lib
*-uninstalled.sh
missing
test-driver
@@ -243,7 +243,7 @@ unique_ptr<sharp_geom_info> sharp_make_fejer1_geom_info (size_t nrings, size_t p
weight[2*k ]=2./(1.-4.*k*k)*sin((k*pi)/nrings);
}
if ((nrings&1)==0) weight[nrings-1]=0.;
mr::r2r_fftpack({size_t(nrings)}, {sizeof(double)}, {sizeof(double)}, {0}, false, false, weight.data(), weight.data(), 1.);
mr::r2r_fftpack({size_t(nrings)}, {1}, {1}, {0}, false, false, weight.data(), weight.data(), 1.);
for (size_t m=0; m<(nrings+1)/2; ++m)
{
@@ -275,7 +275,7 @@ unique_ptr<sharp_geom_info> sharp_make_cc_geom_info (size_t nrings, size_t pprin
for (size_t k=1; k<=(n/2-1); ++k)
weight[2*k-1]=2./(1.-4.*k*k) + dw;
weight[2*(n/2)-1]=(n-3.)/(2*(n/2)-1) -1. -dw*((2-(n&1))*n-1);
mr::r2r_fftpack({size_t(n)}, {sizeof(double)}, {sizeof(double)}, {0}, false, false, weight.data(), weight.data(), 1.);
mr::r2r_fftpack({size_t(n)}, {1}, {1}, {0}, false, false, weight.data(), weight.data(), 1.);
weight[n]=weight[0];
for (size_t m=0; m<(nrings+1)/2; ++m)
@@ -308,7 +308,7 @@ unique_ptr<sharp_geom_info> sharp_make_fejer2_geom_info (size_t nrings, size_t p
for (size_t k=1; k<=(n/2-1); ++k)
weight[2*k-1]=2./(1.-4.*k*k);
weight[2*(n/2)-1]=(n-3.)/(2*(n/2)-1) -1.;
mr::r2r_fftpack({size_t(n)}, {sizeof(double)}, {sizeof(double)}, {0}, false, false, weight.data(), weight.data(), 1.);
mr::r2r_fftpack({size_t(n)}, {1}, {1}, {0}, false, false, weight.data(), weight.data(), 1.);
for (size_t m=0; m<nrings; ++m)
weight[m]=weight[m+1];
@@ -60,6 +60,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "mr_util/unity_roots.h"
#include "mr_util/useful_macros.h"
#include "mr_util/simd.h"
#include "mr_util/mav.h"
#ifndef MRUTIL_NO_THREADING
#include <mutex>
#endif
@@ -76,9 +77,6 @@ template <typename T> T cos(T) = delete;
template <typename T> T sin(T) = delete;
template <typename T> T sqrt(T) = delete;
using shape_t = std::vector<size_t>;
using stride_t = std::vector<ptrdiff_t>;
constexpr bool FORWARD = true,
BACKWARD = false;
@@ -2293,51 +2291,11 @@ template<typename T> std::shared_ptr<T> get_plan(size_t length)
#endif
}
class arr_info
{
protected:
shape_t shp;
stride_t str;
public:
arr_info(const shape_t &shape_, const stride_t &stride_)
: shp(shape_), str(stride_) {}
size_t ndim() const { return shp.size(); }
size_t size() const { return util::prod(shp); }
const shape_t &shape() const { return shp; }
size_t shape(size_t i) const { return shp[i]; }
const stride_t &stride() const { return str; }
const ptrdiff_t &stride(size_t i) const { return str[i]; }
};
template<typename T> class cndarr: public arr_info
{
protected:
const char *d;
public:
cndarr(const void *data_, const shape_t &shape_, const stride_t &stride_)
: arr_info(shape_, stride_),
d(reinterpret_cast<const char *>(data_)) {}
const T &operator[](ptrdiff_t ofs) const
{ return *reinterpret_cast<const T *>(d+ofs); }
};
template<typename T> class ndarr: public cndarr<T>
{
public:
ndarr(void *data_, const shape_t &shape_, const stride_t &stride_)
: cndarr<T>::cndarr(const_cast<const void *>(data_), shape_, stride_)
{}
T &operator[](ptrdiff_t ofs)
{ return *reinterpret_cast<T *>(const_cast<char *>(cndarr<T>::d+ofs)); }
};
template<size_t N> class multi_iter
{
private:
shape_t pos;
const arr_info &iarr, &oarr;
fmav_info iarr, oarr;
ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o;
size_t idim, rem;
@@ -2358,7 +2316,7 @@ template<size_t N> class multi_iter
}
public:
multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_,
multi_iter(const fmav_info &iarr_, const fmav_info &oarr_, size_t idim_,
size_t nshares, size_t myshare)
: pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0),
str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)),
@@ -2412,12 +2370,12 @@ class simple_iter
{
private:
shape_t pos;
const arr_info &arr;
fmav_info arr;
ptrdiff_t p;
size_t rem;
public:
simple_iter(const arr_info &arr_)
simple_iter(const fmav_info &arr_)
: pos(arr_.ndim(), 0), arr(arr_), p(0), rem(arr_.size()) {}
void advance()
{
@@ -2440,7 +2398,7 @@ class rev_iter
{
private:
shape_t pos;
const arr_info &arr;
fmav_info arr;
std::vector<char> rev_axis;
std::vector<char> rev_jump;
size_t last_axis, last_size;
@@ -2449,7 +2407,7 @@ class rev_iter
size_t rem;
public:
rev_iter(const arr_info &arr_, const shape_t &axes)
rev_iter(const fmav_info &arr_, const shape_t &axes)
: pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0),
rev_jump(arr_.ndim(), 1), p(0), rp(0)
{
@@ -2517,15 +2475,15 @@ template<> struct VTYPE<long double>
};
#endif
template<typename T> aligned_array<char> alloc_tmp(const shape_t &shape,
size_t axsize, size_t elemsize)
template<typename T, typename T0> aligned_array<T> alloc_tmp(const shape_t &shape,
size_t axsize)
{
auto othersize = util::prod(shape)/axsize;
auto tmpsize = axsize*((othersize>=VLEN<T>::val) ? VLEN<T>::val : 1);
return aligned_array<char>(tmpsize*elemsize);
auto tmpsize = axsize*((othersize>=VLEN<T0>::val) ? VLEN<T0>::val : 1);
return aligned_array<T>(tmpsize);
}
template<typename T> aligned_array<char> alloc_tmp(const shape_t &shape,
const shape_t &axes, size_t elemsize)
template<typename T, typename T0> aligned_array<T> alloc_tmp(const shape_t &shape,
const shape_t &axes)
{
size_t fullsize=util::prod(shape);
size_t tmpsize=0;
@@ -2533,14 +2491,14 @@ template<typename T> aligned_array<char> alloc_tmp(const shape_t &shape,
{
auto axsize = shape[axes[i]];
auto othersize = fullsize/axsize;
auto sz = axsize*((othersize>=VLEN<T>::val) ? VLEN<T>::val : 1);
auto sz = axsize*((othersize>=VLEN<T0>::val) ? VLEN<T0>::val : 1);
if (sz>tmpsize) tmpsize=sz;
}
return aligned_array<char>(tmpsize*elemsize);
return aligned_array<T>(tmpsize);
}
template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
const cndarr<Cmplx<T>> &src, Cmplx<vtype_t<T>> *MRUTIL_RESTRICT dst)
const cfmav<Cmplx<T>> &src, Cmplx<vtype_t<T>> *MRUTIL_RESTRICT dst)
{
for (size_t i=0; i<it.length_in(); ++i)
for (size_t j=0; j<vlen; ++j)
@@ -2551,7 +2509,7 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
}
template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
const cndarr<T> &src, vtype_t<T> *MRUTIL_RESTRICT dst)
const cfmav<T> &src, vtype_t<T> *MRUTIL_RESTRICT dst)
{
for (size_t i=0; i<it.length_in(); ++i)
for (size_t j=0; j<vlen; ++j)
@@ -2559,7 +2517,7 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
}
template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
const cndarr<T> &src, T *MRUTIL_RESTRICT dst)
const cfmav<T> &src, T *MRUTIL_RESTRICT dst)
{
if (dst == &src[it.iofs(0)]) return; // in-place
for (size_t i=0; i<it.length_in(); ++i)
@@ -2567,7 +2525,7 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
}
template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
const Cmplx<vtype_t<T>> *MRUTIL_RESTRICT src, ndarr<Cmplx<T>> &dst)
const Cmplx<vtype_t<T>> *MRUTIL_RESTRICT src, const fmav<Cmplx<T>> &dst)
{
for (size_t i=0; i<it.length_out(); ++i)
for (size_t j=0; j<vlen; ++j)
@@ -2575,7 +2533,7 @@ template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
}
template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
const vtype_t<T> *MRUTIL_RESTRICT src, ndarr<T> &dst)
const vtype_t<T> *MRUTIL_RESTRICT src, const fmav<T> &dst)
{
for (size_t i=0; i<it.length_out(); ++i)
for (size_t j=0; j<vlen; ++j)
@@ -2583,7 +2541,7 @@ template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
}
template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
const T *MRUTIL_RESTRICT src, ndarr<T> &dst)
const T *MRUTIL_RESTRICT src, const fmav<T> &dst)
{
if (src == &dst[it.oofs(0)]) return; // in-place
for (size_t i=0; i<it.length_out(); ++i)
@@ -2596,7 +2554,7 @@ template <typename T> struct add_vec<Cmplx<T>>
template <typename T> using add_vec_t = typename add_vec<T>::type;
template<typename Tplan, typename T, typename T0, typename Exec>
MRUTIL_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
MRUTIL_NOINLINE void general_nd(const cfmav<T> &in, const fmav<T> &out,
const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec,
const bool allow_inplace=true)
{
@@ -2612,8 +2570,8 @@ MRUTIL_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
util::thread_count(nthreads, in.shape(), axes[iax], VLEN<T>::val),
[&](Scheduler &sched) {
constexpr auto vlen = VLEN<T0>::val;
auto storage = alloc_tmp<T0>(in.shape(), len, sizeof(T));
const auto &tin(iax==0? in : out);
auto storage = alloc_tmp<T,T0>(in.shape(), len);
const auto &tin(iax==0? in : cfmav<T>(out));
multi_iter<vlen> it(tin, out, axes[iax], sched.num_threads(), sched.thread_num());
#ifndef POCKETFFT_NO_VECTORS
if (vlen>1)
@@ -2627,7 +2585,7 @@ MRUTIL_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
while (it.remaining()>0)
{
it.advance(1);
auto buf = allow_inplace && it.stride_out() == sizeof(T) ?
auto buf = allow_inplace && it.stride_out() == 1 ?
&out[it.oofs(0)] : reinterpret_cast<T *>(storage.data());
exec(it, tin, out, buf, *plan, fct);
}
@@ -2641,8 +2599,8 @@ struct ExecC2C
bool forward;
template <typename T0, typename T, size_t vlen> void operator () (
const multi_iter<vlen> &it, const cndarr<Cmplx<T0>> &in,
ndarr<Cmplx<T0>> &out, T * buf, const pocketfft_c<T0> &plan, T0 fct) const
const multi_iter<vlen> &it, const cfmav<Cmplx<T0>> &in,
const fmav<Cmplx<T0>> &out, T * buf, const pocketfft_c<T0> &plan, T0 fct) const
{
copy_input(it, in, buf);
plan.exec(buf, fct, forward);
@@ -2651,7 +2609,7 @@
};
template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
const vtype_t<T> *MRUTIL_RESTRICT src, ndarr<T> &dst)
const vtype_t<T> *MRUTIL_RESTRICT src, const fmav<T> &dst)
{
for (size_t j=0; j<vlen; ++j)
dst[it.oofs(j,0)] = src[0][j];
@@ -2668,7 +2626,7 @@ template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
}
template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
const T *MRUTIL_RESTRICT src, ndarr<T> &dst)
const T *MRUTIL_RESTRICT src, const fmav<T> &dst)
{
dst[it.oofs(0)] = src[0];
size_t i=1, i1=1, i2=it.length_out()-1;
@@ -2684,7 +2642,7 @@ template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
struct ExecHartley
{
template <typename T0, typename T, size_t vlen> void operator () (
const multi_iter<vlen> &it, const cndarr<T0> &in, ndarr<T0> &out,
const multi_iter<vlen> &it, const cfmav<T0> &in, const fmav<T0> &out,
T * buf, const pocketfft_r<T0> &plan, T0 fct) const
{
copy_input(it, in, buf);
@@ -2700,8 +2658,8 @@ struct ExecDcst
bool cosine;
template <typename T0, typename T, typename Tplan, size_t vlen>
void operator () (const multi_iter<vlen> &it, const cndarr<T0> &in,
ndarr<T0> &out, T * buf, const Tplan &plan, T0 fct) const
void operator () (const multi_iter<vlen> &it, const cfmav<T0> &in,
const fmav <T0> &out, T * buf, const Tplan &plan, T0 fct) const
{
copy_input(it, in, buf);
plan.exec(buf, fct, ortho, type, cosine);
@@ -2710,7 +2668,7 @@ struct ExecDcst
};
template<typename T> MRUTIL_NOINLINE void general_r2c(
const cndarr<T> &in, ndarr<Cmplx<T>> &out, size_t axis, bool forward, T fct,
const cfmav<T> &in, const fmav<Cmplx<T>> &out, size_t axis, bool forward, T fct,
size_t nthreads)
{
auto plan = get_plan<pocketfft_r<T>>(in.shape(axis));
@@ -2719,7 +2677,7 @@ template<typename T> MRUTIL_NOINLINE void general_r2c(
util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),
[&](Scheduler &sched) {
constexpr auto vlen = VLEN<T>::val;
auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
auto storage = alloc_tmp<T,T>(in.shape(), len);
multi_iter<vlen> it(in, out, axis, sched.num_threads(), sched.thread_num());
#ifndef POCKETFFT_NO_VECTORS
if (vlen>1)
@@ -2765,7 +2723,7 @@ template<typename T> MRUTIL_NOINLINE void general_r2c(
}); // end of parallel region
}
template<typename T> MRUTIL_NOINLINE void general_c2r(
const cndarr<Cmplx<T>> &in, ndarr<T> &out, size_t axis, bool forward, T fct,
const cfmav<Cmplx<T>> &in, const fmav<T> &out, size_t axis, bool forward, T fct,
size_t nthreads)
{
auto plan = get_plan<pocketfft_r<T>>(out.shape(axis));
@@ -2774,7 +2732,7 @@ template<typename T> MRUTIL_NOINLINE void general_c2r(
util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),
[&](Scheduler &sched) {
constexpr auto vlen = VLEN<T>::val;
auto storage = alloc_tmp<T>(out.shape(), len, sizeof(T));
auto storage = alloc_tmp<T,T>(out.shape(), len);
multi_iter<vlen> it(in, out, axis, sched.num_threads(), sched.thread_num());
#ifndef POCKETFFT_NO_VECTORS
if (vlen>1)
@@ -2841,7 +2799,7 @@ struct ExecR2R
bool r2c, forward;
template <typename T0, typename T, size_t vlen> void operator () (
const multi_iter<vlen> &it, const cndarr<T0> &in, ndarr<T0> &out, T * buf,
const multi_iter<vlen> &it, const cfmav<T0> &in, const fmav<T0> &out, T * buf,
const pocketfft_r<T0> &plan, T0 fct) const
{
copy_input(it, in, buf);
@@ -2863,8 +2821,8 @@ template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in,
{
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<Cmplx<T>> ain(data_in, shape, stride_in);
ndarr<Cmplx<T>> aout(data_out, shape, stride_out);
const cfmav<Cmplx<T>> ain(reinterpret_cast<const Cmplx<T> *>(data_in), shape, stride_in);
const fmav<Cmplx<T>> aout(reinterpret_cast<Cmplx<T> *>(data_out), shape, stride_out);
general_nd<pocketfft_c<T>>(ain, aout, axes, fct, nthreads, ExecC2C{forward});
}
@@ -2875,8 +2833,8 @@ template<typename T> void dct(const shape_t &shape,
if ((type<1) || (type>4)) throw std::invalid_argument("invalid DCT type");
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out);
const cfmav<T> ain(data_in, shape, stride_in);
const fmav<T> aout(data_out, shape, stride_out);
const ExecDcst exec{ortho, type, true};
if (type==1)
general_nd<T_dct1<T>>(ain, aout, axes, fct, nthreads, exec);
@@ -2893,8 +2851,8 @@ template<typename T> void dst(const shape_t &shape,
if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type");
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out);
const cfmav<T> ain(data_in, shape, stride_in);
const fmav<T> aout(data_out, shape, stride_out);
const ExecDcst exec{ortho, type, false};
if (type==1)
general_nd<T_dst1<T>>(ain, aout, axes, fct, nthreads, exec);
@@ -2911,10 +2869,10 @@ template<typename T> void r2c(const shape_t &shape_in,
{
if (util::prod(shape_in)==0) return;
util::sanity_check(shape_in, stride_in, stride_out, false, axis);
cndarr<T> ain(data_in, shape_in, stride_in);
const cfmav<T> ain(data_in, shape_in, stride_in);
shape_t shape_out(shape_in);
shape_out[axis] = shape_in[axis]/2 + 1;
ndarr<Cmplx<T>> aout(data_out, shape_out, stride_out);
const fmav<Cmplx<T>> aout(reinterpret_cast<Cmplx<T> *>(data_out), shape_out, stride_out);
general_r2c(ain, aout, axis, forward, fct, nthreads);
}
@@ -2945,8 +2903,8 @@ template<typename T> void c2r(const shape_t &shape_out,
util::sanity_check(shape_out, stride_in, stride_out, false, axis);
shape_t shape_in(shape_out);
shape_in[axis] = shape_out[axis]/2 + 1;
cndarr<Cmplx<T>> ain(data_in, shape_in, stride_in);
ndarr<T> aout(data_out, shape_out, stride_out);
const cfmav<Cmplx<T>> ain(reinterpret_cast<const Cmplx<T> *>(data_in), shape_in, stride_in);
const fmav<T> aout(data_out, shape_out, stride_out);
general_c2r(ain, aout, axis, forward, fct, nthreads);
}
@@ -2964,7 +2922,7 @@ template<typename T> void c2r(const shape_t &shape_out,
shape_in[axes.back()] = shape_out[axes.back()]/2 + 1;
auto nval = util::prod(shape_in);
stride_t stride_inter(shape_in.size());
stride_inter.back() = sizeof(Cmplx<T>);
stride_inter.back() = 1;
for (int i=int(shape_in.size())-2; i>=0; --i)
stride_inter[size_t(i)] =
stride_inter[size_t(i+1)]*ptrdiff_t(shape_in[size_t(i+1)]);
@@ -2983,8 +2941,8 @@ template<typename T> void r2r_fftpack(const shape_t &shape,
{
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out);
const cfmav<T> ain(data_in, shape, stride_in);
const fmav<T> aout(data_out, shape, stride_out);
general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads,
ExecR2R{real2hermitian, forward});
}
@@ -2995,8 +2953,8 @@ template<typename T> void r2r_separable_hartley(const shape_t &shape,
{
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
cndarr<T> ain(data_in, shape, stride_in);
ndarr<T> aout(data_out, shape, stride_out);
const cfmav<T> ain(data_in, shape, stride_in);
const fmav<T> aout(data_out, shape, stride_out);
general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads, ExecHartley{},
false);
}
@@ -3014,12 +2972,12 @@ template<typename T> void r2r_genuine_hartley(const shape_t &shape,
tshp[axes.back()] = tshp[axes.back()]/2+1;
aligned_array<std::complex<T>> tdata(util::prod(tshp));
stride_t tstride(shape.size());
tstride.back()=sizeof(std::complex<T>);
tstride.back()=1;
for (size_t i=tstride.size()-1; i>0; --i)
tstride[i-1]=tstride[i]*ptrdiff_t(tshp[i]);
r2c(shape, stride_in, tstride, axes, true, data_in, tdata.data(), fct, nthreads);
cndarr<Cmplx<T>> atmp(tdata.data(), tshp, tstride);
ndarr<T> aout(data_out, shape, stride_out);
const cfmav<Cmplx<T>> atmp(reinterpret_cast<const Cmplx<T> *>(tdata.data()), tshp, tstride);
const fmav<T> aout(data_out, shape, stride_out);
simple_iter iin(atmp);
rev_iter iout(aout, axes);
while(iin.remaining()>0)
@@ -3035,8 +2993,6 @@ template<typename T> void r2r_genuine_hartley(const shape_t &shape,
using detail_fft::FORWARD;
using detail_fft::BACKWARD;
using detail_fft::shape_t;
using detail_fft::stride_t;
using detail_fft::c2c;
using detail_fft::c2r;
using detail_fft::r2c;
@@ -24,6 +24,8 @@
#include <cstdlib>
#include <array>
#include <vector>
#include "mr_util/error_handling.h"
namespace mr {
@@ -31,7 +33,100 @@ namespace detail_mav {
using namespace std;
using shape_t = vector<size_t>;
using stride_t = vector<ptrdiff_t>;
class fmav_info
{
protected:
shape_t shp;
stride_t str;
static size_t prod(const shape_t &shape)
{
size_t res=1;
for (auto sz: shape)
res*=sz;
return res;
}
public:
fmav_info(const shape_t &shape_, const stride_t &stride_)
: shp(shape_), str(stride_)
{ MR_assert(shp.size()==str.size(), "dimensions mismatch"); }
fmav_info(const shape_t &shape_)
: shp(shape_), str(shape_.size())
{
auto ndim = shp.size();
str[ndim-1]=1;
for (size_t i=2; i<=ndim; ++i)
str[ndim-i] = str[ndim-i+1]*ptrdiff_t(shp[ndim-i+1]);
}
size_t ndim() const { return shp.size(); }
size_t size() const { return prod(shp); }
const shape_t &shape() const { return shp; }
size_t shape(size_t i) const { return shp[i]; }
const stride_t &stride() const { return str; }
const ptrdiff_t &stride(size_t i) const { return str[i]; }
bool last_contiguous() const
{ return (str.back()==1); }
bool contiguous() const
{
auto ndim = shp.size();
ptrdiff_t stride=1;
for (size_t i=0; i<ndim; ++i)
{
if (str[ndim-1-i]!=stride) return false;
stride *= shp[ndim-1-i];
}
return true;
}
bool conformable(const fmav_info &other) const
{ return shp==other.shp; }
};
// "mav" stands for "multidimensional array view"
template<typename T> class cfmav: public fmav_info
{
protected:
T *d;
public:
cfmav(const T *d_, const shape_t &shp_, const stride_t &str_)
: fmav_info(shp_, str_), d(const_cast<T *>(d_)) {}
cfmav(const T *d_, const shape_t &shp_)
: fmav_info(shp_), d(const_cast<T *>(d_)) {}
template<typename I> const T &operator[](I i) const
{ return d[i]; }
const T *data() const
{ return d; }
};
template<typename T> class fmav: public cfmav<T>
{
protected:
using parent = cfmav<T>;
using parent::d;
using parent::shp;
using parent::str;
public:
fmav(T *d_, const shape_t &shp_, const stride_t &str_)
: parent(d_, shp_, str_) {}
fmav(T *d_, const shape_t &shp_)
: parent(d_, shp_) {}
template<typename I> T &operator[](I i) const
{ return d[i]; }
using parent::shape;
using parent::stride;
T *data() const
{ return d; }
using parent::last_contiguous;
using parent::contiguous;
using parent::conformable;
};
template<typename T> using const_fmav = fmav<const T>;
template<typename T, size_t ndim> class mav
{
static_assert((ndim>0) && (ndim<3), "only supports 1D and 2D arrays");
@@ -49,8 +144,8 @@ template<typename T, size_t ndim> class mav
: d(d_), shp(shp_)
{
str[ndim-1]=1;
for (size_t d=2; d<=ndim; ++d)
str[ndim-d] = str[ndim-d+1]*shp[ndim-d+1];
for (size_t i=2; i<=ndim; ++i)
str[ndim-i] = str[ndim-i+1]*shp[ndim-i+1];
}
T &operator[](size_t i) const
{ return operator()(i); }
@@ -98,6 +193,8 @@ template<typename T, size_t ndim> class mav
for (size_t j=0; j<shp[1]; ++j)
d[str[0]*i + str[1]*j] = val;
}
template<typename T2> bool conformable(const mav<T2,ndim> &other) const
{ return shp==other.shp; }
};
template<typename T, size_t ndim> using const_mav = mav<const T, ndim>;
@@ -112,6 +209,12 @@ template<typename T, size_t ndim> const_mav<T, ndim> nullmav()
}
using detail_mav::shape_t;
using detail_mav::stride_t;
using detail_mav::fmav_info;
using detail_mav::fmav;
using detail_mav::cfmav;
using detail_mav::const_fmav;
using detail_mav::mav;
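
For reference, a short usage sketch of the fmav classes introduced above; the buffer contents, shapes and names are illustrative only:

#include <vector>
// #include "mr_util/mav.h"   // header shown in this commit

void fmav_usage_sketch()
  {
  std::vector<double> buf(6*4, 0.);
  // Writable C-ordered 6x4 view; the shape-only constructor computes the
  // strides itself (in elements, innermost stride 1).
  mr::fmav<double> a(buf.data(), {6, 4});
  // Read-only view of the same buffer with explicit element strides
  // (here a transposed 4x6 view).
  mr::cfmav<double> at(buf.data(), {4, 6}, {1, 4});
  if (a.contiguous() && a.last_contiguous())
    a[a.size()-1] = 1.;                      // flat indexing via operator[]
  bool same_shape = a.conformable(mr::fmav_info({6, 4}));
  (void)same_shape; (void)at;
  }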