Commit 8705186d authored by Martin Reinecke's avatar Martin Reinecke

more mavs

parent 33590924
This diff is collapsed.
......@@ -77,7 +77,7 @@ class fmav_info
for (size_t i=0; i<ndim; ++i)
{
if (str[ndim-1-i]!=stride) return false;
stride *= shp[ndim-1-i];
stride *= ptrdiff_t(shp[ndim-1-i]);
}
return true;
}
......@@ -96,6 +96,8 @@ template<typename T> class cfmav: public fmav_info
: fmav_info(shp_, str_), d(const_cast<T *>(d_)) {}
cfmav(const T *d_, const shape_t &shp_)
: fmav_info(shp_), d(const_cast<T *>(d_)) {}
cfmav(const T* d_, const fmav_info &info)
: fmav_info(info), d(const_cast<T *>(d_)) {}
template<typename I> const T &operator[](I i) const
{ return d[i]; }
const T *data() const
......@@ -114,6 +116,8 @@ template<typename T> class fmav: public cfmav<T>
: parent(d_, shp_, str_) {}
fmav(T *d_, const shape_t &shp_)
: parent(d_, shp_) {}
fmav(T* d_, const fmav_info &info)
: parent(d_, info) {}
template<typename I> T &operator[](I i) const
{ return d[i]; }
using parent::shape;
......@@ -125,36 +129,40 @@ template<typename T> class fmav: public cfmav<T>
using parent::conformable;
};
template<typename T> using const_fmav = fmav<const T>;
template<typename T, size_t ndim> class mav
template<typename T, size_t ndim> class cmav
{
static_assert((ndim>0) && (ndim<3), "only supports 1D and 2D arrays");
private:
protected:
T *d;
array<size_t, ndim> shp;
array<ptrdiff_t, ndim> str;
public:
mav(T *d_, const array<size_t,ndim> &shp_,
cmav(const T *d_, const array<size_t,ndim> &shp_,
const array<ptrdiff_t,ndim> &str_)
: d(d_), shp(shp_), str(str_) {}
mav(T *d_, const array<size_t,ndim> &shp_)
: d(d_), shp(shp_)
: d(const_cast<T *>(d_)), shp(shp_), str(str_) {}
cmav(const T *d_, const array<size_t,ndim> &shp_)
: d(const_cast<T *>(d_)), shp(shp_)
{
str[ndim-1]=1;
for (size_t i=2; i<=ndim; ++i)
str[ndim-i] = str[ndim-i+1]*shp[ndim-i+1];
}
T &operator[](size_t i) const
operator cfmav<T>() const
{
return cfmav<T>(d, {shp.begin(), shp.end()}, {str.begin(), str.end()});
}
const T &operator[](size_t i) const
{ return operator()(i); }
T &operator()(size_t i) const
const T &operator()(size_t i) const
{
static_assert(ndim==1, "ndim must be 1");
return d[str[0]*i];
}
T &operator()(size_t i, size_t j) const
const T &operator()(size_t i, size_t j) const
{
static_assert(ndim==2, "ndim must be 2");
return d[str[0]*i + str[1]*j];
......@@ -168,7 +176,7 @@ template<typename T, size_t ndim> class mav
return res;
}
ptrdiff_t stride(size_t i) const { return str[i]; }
T *data() const
const T *data() const
{ return d; }
bool last_contiguous() const
{ return (str[ndim-1]==1) || (str[ndim-1]==0); }
......@@ -182,6 +190,46 @@ template<typename T, size_t ndim> class mav
}
return true;
}
template<typename T2> bool conformable(const cmav<T2,ndim> &other) const
{ return shp==other.shp; }
};
template<typename T, size_t ndim> class mav: public cmav<T, ndim>
{
protected:
using parent = cmav<T, ndim>;
using parent::d;
using parent::shp;
using parent::str;
public:
mav(T *d_, const array<size_t,ndim> &shp_,
const array<ptrdiff_t,ndim> &str_)
: parent(d_, shp_, str_) {}
mav(T *d_, const array<size_t,ndim> &shp_)
: parent(d_, shp_) {}
operator fmav<T>() const
{
return fmav<T>(d, {shp.begin(), shp.end()}, {str.begin(), str.end()});
}
T &operator[](size_t i) const
{ return operator()(i); }
T &operator()(size_t i) const
{
static_assert(ndim==1, "ndim must be 1");
return d[str[0]*i];
}
T &operator()(size_t i, size_t j) const
{
static_assert(ndim==2, "ndim must be 2");
return d[str[0]*i + str[1]*j];
}
using parent::shape;
using parent::stride;
using parent::size;
T *data() const
{ return d; }
using parent::last_contiguous;
using parent::contiguous;
void fill(const T &val) const
{
// FIXME: special cases for contiguous arrays and/or zeroing?
......@@ -193,20 +241,9 @@ template<typename T, size_t ndim> class mav
for (size_t j=0; j<shp[1]; ++j)
d[str[0]*i + str[1]*j] = val;
}
template<typename T2> bool conformable(const mav<T2,ndim> &other) const
{ return shp==other.shp; }
using parent::conformable;
};
template<typename T, size_t ndim> using const_mav = mav<const T, ndim>;
template<typename T, size_t ndim> const_mav<T, ndim> cmav (const mav<T, ndim> &mav)
{ return const_mav<T, ndim>(mav.data(), mav.shape()); }
template<typename T, size_t ndim> const_mav<T, ndim> nullmav()
{
array<size_t,ndim> shp;
shp.fill(0);
return const_mav<T, ndim>(nullptr, shp);
}
}
using detail_mav::shape_t;
......@@ -214,11 +251,8 @@ using detail_mav::stride_t;
using detail_mav::fmav_info;
using detail_mav::fmav;
using detail_mav::cfmav;
using detail_mav::const_fmav;
using detail_mav::mav;
using detail_mav::const_mav;
using detail_mav::cmav;
using detail_mav::nullmav;
}
......
This diff is collapsed.
......@@ -85,7 +85,7 @@ template<size_t ndim, typename T> mav<T,ndim> make_mav(pyarr<T> &in)
}
return mav<T, ndim>(in.mutable_data(),dims,str);
}
template<size_t ndim, typename T> const_mav<T,ndim> make_const_mav(const pyarr<T> &in)
template<size_t ndim, typename T> cmav<T,ndim> make_cmav(const pyarr<T> &in)
{
MR_assert(ndim==in.ndim(), "dimension mismatch");
array<size_t,ndim> dims;
......@@ -96,7 +96,7 @@ template<size_t ndim, typename T> const_mav<T,ndim> make_const_mav(const pyarr<T
str[i]=in.strides(i)/sizeof(T);
MR_assert(str[i]*ptrdiff_t(sizeof(T))==in.strides(i), "weird strides");
}
return const_mav<T, ndim>(in.data(),dims,str);
return cmav<T, ndim>(in.data(),dims,str);
}
template<typename T> py::array ms2dirty_general2(const py::array &uvw_,
......@@ -105,13 +105,13 @@ template<typename T> py::array ms2dirty_general2(const py::array &uvw_,
bool do_wstacking, size_t nthreads, size_t verbosity)
{
auto uvw = getPyarr<double>(uvw_, "uvw");
auto uvw2 = make_const_mav<2>(uvw);
auto uvw2 = make_cmav<2>(uvw);
auto freq = getPyarr<double>(freq_, "freq");
auto freq2 = make_const_mav<1>(freq);
auto freq2 = make_cmav<1>(freq);
auto ms = getPyarr<complex<T>>(ms_, "ms");
auto ms2 = make_const_mav<2>(ms);
auto ms2 = make_cmav<2>(ms);
auto wgt = providePotentialArray<T>(wgt_, "wgt", {ms2.shape(0),ms2.shape(1)});
auto wgt2 = make_const_mav<2>(wgt);
auto wgt2 = make_cmav<2>(wgt);
auto dirty = makeArray<T>({npix_x,npix_y});
auto dirty2 = make_mav<2>(dirty);
{
......@@ -187,13 +187,13 @@ template<typename T> py::array dirty2ms_general2(const py::array &uvw_,
bool do_wstacking, size_t nthreads, size_t verbosity)
{
auto uvw = getPyarr<double>(uvw_, "uvw");
auto uvw2 = make_const_mav<2>(uvw);
auto uvw2 = make_cmav<2>(uvw);
auto freq = getPyarr<double>(freq_, "freq");
auto freq2 = make_const_mav<1>(freq);
auto freq2 = make_cmav<1>(freq);
auto dirty = getPyarr<T>(dirty_, "dirty");
auto dirty2 = make_const_mav<2>(dirty);
auto dirty2 = make_cmav<2>(dirty);
auto wgt = providePotentialArray<T>(wgt_, "wgt", {uvw2.shape(0),freq2.shape(0)});
auto wgt2 = make_const_mav<2>(wgt);
auto wgt2 = make_cmav<2>(wgt);
auto ms = makeArray<complex<T>>({uvw2.shape(0),freq2.shape(0)});
auto ms2 = make_mav<2>(ms);
{
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment