Commit 33251a05 authored by Martin Reinecke
Browse files

stage 2

parent d73de808
...@@ -498,9 +498,9 @@ template<typename I> template<typename I2> ...@@ -498,9 +498,9 @@ template<typename I> template<typename I2>
double dr=base[o].max_pixrad(); // safety distance double dr=base[o].max_pixrad(); // safety distance
for (size_t i=0; i<nv; ++i) for (size_t i=0; i<nv; ++i)
{ {
crlimit(o,i,0) = (rad[i]+dr>pi) ? -1. : cos(rad[i]+dr); crlimit.v(o,i,0) = (rad[i]+dr>pi) ? -1. : cos(rad[i]+dr);
crlimit(o,i,1) = (o==0) ? cos(rad[i]) : crlimit(0,i,1); crlimit.v(o,i,1) = (o==0) ? cos(rad[i]) : crlimit(0,i,1);
crlimit(o,i,2) = (rad[i]-dr<0.) ? 1. : cos(rad[i]-dr); crlimit.v(o,i,2) = (rad[i]-dr<0.) ? 1. : cos(rad[i]-dr);
} }
} }
...@@ -563,9 +563,9 @@ template<typename I> void T_Healpix_Base<I>::query_multidisc_general ...@@ -563,9 +563,9 @@ template<typename I> void T_Healpix_Base<I>::query_multidisc_general
double dr=base[o].max_pixrad(); // safety distance double dr=base[o].max_pixrad(); // safety distance
for (size_t i=0; i<nv; ++i) for (size_t i=0; i<nv; ++i)
{ {
crlimit(o,i,0) = (rad[i]+dr>pi) ? -1. : cos(rad[i]+dr); crlimit.v(o,i,0) = (rad[i]+dr>pi) ? -1. : cos(rad[i]+dr);
crlimit(o,i,1) = (o==0) ? cos(rad[i]) : crlimit(0,i,1); crlimit.v(o,i,1) = (o==0) ? cos(rad[i]) : crlimit(0,i,1);
crlimit(o,i,2) = (rad[i]-dr<0.) ? 1. : cos(rad[i]-dr); crlimit.v(o,i,2) = (rad[i]-dr<0.) ? 1. : cos(rad[i]-dr);
} }
} }
......
...@@ -650,25 +650,28 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it, ...@@ -650,25 +650,28 @@ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it, template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
const Cmplx<native_simd<T>> *MRUTIL_RESTRICT src, fmav<Cmplx<T>> &dst) const Cmplx<native_simd<T>> *MRUTIL_RESTRICT src, fmav<Cmplx<T>> &dst)
{ {
auto ptr=dst.vdata();
for (size_t i=0; i<it.length_out(); ++i) for (size_t i=0; i<it.length_out(); ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
dst[it.oofs(j,i)].Set(src[i].r[j],src[i].i[j]); ptr[it.oofs(j,i)].Set(src[i].r[j],src[i].i[j]);
} }
template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it, template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
const native_simd<T> *MRUTIL_RESTRICT src, fmav<T> &dst) const native_simd<T> *MRUTIL_RESTRICT src, fmav<T> &dst)
{ {
auto ptr=dst.vdata();
for (size_t i=0; i<it.length_out(); ++i) for (size_t i=0; i<it.length_out(); ++i)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
dst[it.oofs(j,i)] = src[i][j]; ptr[it.oofs(j,i)] = src[i][j];
} }
template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it, template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
const T *MRUTIL_RESTRICT src, fmav<T> &dst) const T *MRUTIL_RESTRICT src, fmav<T> &dst)
{ {
auto ptr=dst.vdata();
if (src == &dst[it.oofs(0)]) return; // in-place if (src == &dst[it.oofs(0)]) return; // in-place
for (size_t i=0; i<it.length_out(); ++i) for (size_t i=0; i<it.length_out(); ++i)
dst[it.oofs(i)] = src[i]; ptr[it.oofs(i)] = src[i];
} }
template <typename T> struct add_vec { using type = native_simd<T>; }; template <typename T> struct add_vec { using type = native_simd<T>; };
...@@ -709,7 +712,7 @@ MRUTIL_NOINLINE void general_nd(const fmav<T> &in, fmav<T> &out, ...@@ -709,7 +712,7 @@ MRUTIL_NOINLINE void general_nd(const fmav<T> &in, fmav<T> &out,
{ {
it.advance(1); it.advance(1);
auto buf = allow_inplace && it.stride_out() == 1 ? auto buf = allow_inplace && it.stride_out() == 1 ?
&out[it.oofs(0)] : reinterpret_cast<T *>(storage.data()); &out.vraw(it.oofs(0)) : reinterpret_cast<T *>(storage.data());
exec(it, tin, out, buf, *plan, fct); exec(it, tin, out, buf, *plan, fct);
} }
}); // end of parallel region }); // end of parallel region
...@@ -734,32 +737,34 @@ struct ExecC2C ...@@ -734,32 +737,34 @@ struct ExecC2C
template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it, template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
const native_simd<T> *MRUTIL_RESTRICT src, fmav<T> &dst) const native_simd<T> *MRUTIL_RESTRICT src, fmav<T> &dst)
{ {
auto ptr = dst.vdata();
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
dst[it.oofs(j,0)] = src[0][j]; ptr[it.oofs(j,0)] = src[0][j];
size_t i=1, i1=1, i2=it.length_out()-1; size_t i=1, i1=1, i2=it.length_out()-1;
for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2) for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
{ {
dst[it.oofs(j,i1)] = src[i][j]+src[i+1][j]; ptr[it.oofs(j,i1)] = src[i][j]+src[i+1][j];
dst[it.oofs(j,i2)] = src[i][j]-src[i+1][j]; ptr[it.oofs(j,i2)] = src[i][j]-src[i+1][j];
} }
if (i<it.length_out()) if (i<it.length_out())
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
dst[it.oofs(j,i1)] = src[i][j]; ptr[it.oofs(j,i1)] = src[i][j];
} }
template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it, template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
const T *MRUTIL_RESTRICT src, fmav<T> &dst) const T *MRUTIL_RESTRICT src, fmav<T> &dst)
{ {
dst[it.oofs(0)] = src[0]; auto ptr = dst.vdata();
ptr[it.oofs(0)] = src[0];
size_t i=1, i1=1, i2=it.length_out()-1; size_t i=1, i1=1, i2=it.length_out()-1;
for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2) for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2)
{ {
dst[it.oofs(i1)] = src[i]+src[i+1]; ptr[it.oofs(i1)] = src[i]+src[i+1];
dst[it.oofs(i2)] = src[i]-src[i+1]; ptr[it.oofs(i2)] = src[i]-src[i+1];
} }
if (i<it.length_out()) if (i<it.length_out())
dst[it.oofs(i1)] = src[i]; ptr[it.oofs(i1)] = src[i];
} }
struct ExecHartley struct ExecHartley
...@@ -810,20 +815,21 @@ template<typename T> MRUTIL_NOINLINE void general_r2c( ...@@ -810,20 +815,21 @@ template<typename T> MRUTIL_NOINLINE void general_r2c(
auto tdatav = reinterpret_cast<native_simd<T> *>(storage.data()); auto tdatav = reinterpret_cast<native_simd<T> *>(storage.data());
copy_input(it, in, tdatav); copy_input(it, in, tdatav);
plan->exec(tdatav, fct, true); plan->exec(tdatav, fct, true);
auto vout = out.vdata();
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,0)].Set(tdatav[0][j]); vout[it.oofs(j,0)].Set(tdatav[0][j]);
size_t i=1, ii=1; size_t i=1, ii=1;
if (forward) if (forward)
for (; i<len-1; i+=2, ++ii) for (; i<len-1; i+=2, ++ii)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,ii)].Set(tdatav[i][j], tdatav[i+1][j]); vout[it.oofs(j,ii)].Set(tdatav[i][j], tdatav[i+1][j]);
else else
for (; i<len-1; i+=2, ++ii) for (; i<len-1; i+=2, ++ii)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,ii)].Set(tdatav[i][j], -tdatav[i+1][j]); vout[it.oofs(j,ii)].Set(tdatav[i][j], -tdatav[i+1][j]);
if (i<len) if (i<len)
for (size_t j=0; j<vlen; ++j) for (size_t j=0; j<vlen; ++j)
out[it.oofs(j,ii)].Set(tdatav[i][j]); vout[it.oofs(j,ii)].Set(tdatav[i][j]);
} }
#endif #endif
while (it.remaining()>0) while (it.remaining()>0)
...@@ -832,16 +838,17 @@ template<typename T> MRUTIL_NOINLINE void general_r2c( ...@@ -832,16 +838,17 @@ template<typename T> MRUTIL_NOINLINE void general_r2c(
auto tdata = reinterpret_cast<T *>(storage.data()); auto tdata = reinterpret_cast<T *>(storage.data());
copy_input(it, in, tdata); copy_input(it, in, tdata);
plan->exec(tdata, fct, true); plan->exec(tdata, fct, true);
out[it.oofs(0)].Set(tdata[0]); auto vout = out.vdata();
vout[it.oofs(0)].Set(tdata[0]);
size_t i=1, ii=1; size_t i=1, ii=1;
if (forward) if (forward)
for (; i<len-1; i+=2, ++ii) for (; i<len-1; i+=2, ++ii)
out[it.oofs(ii)].Set(tdata[i], tdata[i+1]); vout[it.oofs(ii)].Set(tdata[i], tdata[i+1]);
else else
for (; i<len-1; i+=2, ++ii) for (; i<len-1; i+=2, ++ii)
out[it.oofs(ii)].Set(tdata[i], -tdata[i+1]); vout[it.oofs(ii)].Set(tdata[i], -tdata[i+1]);
if (i<len) if (i<len)
out[it.oofs(ii)].Set(tdata[i]); vout[it.oofs(ii)].Set(tdata[i]);
} }
}); // end of parallel region }); // end of parallel region
} }
...@@ -944,7 +951,7 @@ template<typename T> void c2c(const fmav<std::complex<T>> &in, ...@@ -944,7 +951,7 @@ template<typename T> void c2c(const fmav<std::complex<T>> &in,
util::sanity_check_onetype(in, out, in.data()==out.data(), axes); util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
if (in.size()==0) return; if (in.size()==0) return;
fmav<Cmplx<T>> in2(reinterpret_cast<const Cmplx<T> *>(in.data()), in); fmav<Cmplx<T>> in2(reinterpret_cast<const Cmplx<T> *>(in.data()), in);
fmav<Cmplx<T>> out2(reinterpret_cast<Cmplx<T> *>(out.data()), out, out.writable()); fmav<Cmplx<T>> out2(reinterpret_cast<Cmplx<T> *>(out.vdata()), out, out.writable());
general_nd<pocketfft_c<T>>(in2, out2, axes, fct, nthreads, ExecC2C{forward}); general_nd<pocketfft_c<T>>(in2, out2, axes, fct, nthreads, ExecC2C{forward});
} }
...@@ -983,7 +990,7 @@ template<typename T> void r2c(const fmav<T> &in, ...@@ -983,7 +990,7 @@ template<typename T> void r2c(const fmav<T> &in,
{ {
util::sanity_check_cr(out, in, axis); util::sanity_check_cr(out, in, axis);
if (in.size()==0) return; if (in.size()==0) return;
fmav<Cmplx<T>> out2(reinterpret_cast<Cmplx<T> *>(out.data()), out, out.writable()); fmav<Cmplx<T>> out2(reinterpret_cast<Cmplx<T> *>(out.vdata()), out, out.writable());
general_r2c(in, out2, axis, forward, fct, nthreads); general_r2c(in, out2, axis, forward, fct, nthreads);
} }
...@@ -1055,11 +1062,12 @@ template<typename T> void r2r_genuine_hartley(const fmav<T> &in, ...@@ -1055,11 +1062,12 @@ template<typename T> void r2r_genuine_hartley(const fmav<T> &in,
r2c(in, atmp, axes, true, fct, nthreads); r2c(in, atmp, axes, true, fct, nthreads);
simple_iter iin(atmp); simple_iter iin(atmp);
rev_iter iout(out, axes); rev_iter iout(out, axes);
auto vout = out.vdata();
while(iin.remaining()>0) while(iin.remaining()>0)
{ {
auto v = atmp[iin.ofs()]; auto v = atmp[iin.ofs()];
out[iout.ofs()] = v.real()+v.imag(); vout[iout.ofs()] = v.real()+v.imag();
out[iout.rev_ofs()] = v.real()-v.imag(); vout[iout.rev_ofs()] = v.real()-v.imag();
iin.advance(); iout.advance(); iin.advance(); iout.advance();
} }
} }
......
...@@ -107,33 +107,27 @@ template<typename T> class membuf ...@@ -107,33 +107,27 @@ template<typename T> class membuf
membuf(size_t sz) membuf(size_t sz)
: ptr(make_unique<vector<T>>(sz)), d(ptr->data()), rw(true) {} : ptr(make_unique<vector<T>>(sz)), d(ptr->data()), rw(true) {}
membuf(const membuf &other) = default; membuf(const membuf &other) = default;
membuf(membuf &other) = default;
membuf(membuf &&other) = default; membuf(membuf &&other) = default;
// Not for public use! // Not for public use!
membuf(T *d_, const Tsp &p, bool rw_) membuf(T *d_, const Tsp &p, bool rw_)
: ptr(p), d(d_), rw(rw_) {} : ptr(p), d(d_), rw(rw_) {}
template<typename I> const T &val(I i) const template<typename I> T &vraw(I i)
{ return d[i]; }
template<typename I> T &val(I i)
{ {
MR_assert(rw, "array is nor writable"); MR_assert(rw, "array is not writable");
return const_cast<T *>(d)[i]; return const_cast<T *>(d)[i];
} }
template<typename I> const T &operator[](I i) const template<typename I> const T &operator[](I i) const
{ return d[i]; } { return d[i]; }
template<typename I> T &operator[](I i)
{
MR_assert(rw, "array is nor writable");
return const_cast<T *>(d)[i];
}
const T *data() const const T *data() const
{ return d; } { return d; }
T *data() T *vdata()
{ {
MR_assert(rw, "array is nor writable"); MR_assert(rw, "array is not writable");
return const_cast<T *>(d); return const_cast<T *>(d);
} }
bool writable() const { return rw; } bool writable() { return rw; }
}; };
// "mav" stands for "multidimensional array view" // "mav" stands for "multidimensional array view"
...@@ -157,11 +151,11 @@ template<typename T> class fmav: public fmav_info, public membuf<T> ...@@ -157,11 +151,11 @@ template<typename T> class fmav: public fmav_info, public membuf<T>
: fmav_info(info), membuf<T>(d_, rw_) {} : fmav_info(info), membuf<T>(d_, rw_) {}
fmav(const T* d_, const fmav_info &info) fmav(const T* d_, const fmav_info &info)
: fmav_info(info), membuf<T>(d_) {} : fmav_info(info), membuf<T>(d_) {}
fmav(const fmav &other) = delete; fmav(const fmav &other) = default;
fmav(fmav &&other) = default; fmav(fmav &&other) = default;
// Not for public use! // Not for public use!
fmav(T *d_, const Tsp &p, const shape_t &shp_, const stride_t &str_, bool rw_) fmav(membuf<T> &buf, const shape_t &shp_, const stride_t &str_)
: fmav_info(shp_, str_), membuf<T>(d_, p, rw_) {} : fmav_info(shp_, str_), membuf<T>(buf) {}
}; };
template<size_t ndim> class mav_info template<size_t ndim> class mav_info
...@@ -194,7 +188,6 @@ template<size_t ndim> class mav_info ...@@ -194,7 +188,6 @@ template<size_t ndim> class mav_info
for (size_t i=2; i<=ndim; ++i) for (size_t i=2; i<=ndim; ++i)
str[ndim-i] = str[ndim-i+1]*ptrdiff_t(shp[ndim-i+1]); str[ndim-i] = str[ndim-i+1]*ptrdiff_t(shp[ndim-i+1]);
} }
// size_t ndim() const { return shp.size(); }
size_t size() const { return sz; } size_t size() const { return sz; }
const shape_t &shape() const { return shp; } const shape_t &shape() const { return shp; }
size_t shape(size_t i) const { return shp[i]; } size_t shape(size_t i) const { return shp[i]; }
...@@ -214,6 +207,21 @@ template<size_t ndim> class mav_info ...@@ -214,6 +207,21 @@ template<size_t ndim> class mav_info
} }
bool conformable(const mav_info &other) const bool conformable(const mav_info &other) const
{ return shp==other.shp; } { return shp==other.shp; }
ptrdiff_t idx(size_t i) const
{
static_assert(ndim==1, "ndim must be 1");
return str[0]*i;
}
ptrdiff_t idx(size_t i, size_t j) const
{
static_assert(ndim==2, "ndim must be 2");
return str[0]*i + str[1]*j;
}
ptrdiff_t idx(size_t i, size_t j, size_t k) const
{
static_assert(ndim==3, "ndim must be 3");
return str[0]*i + str[1]*j + str[2]*k;
}
}; };
template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membuf<T> template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membuf<T>
...@@ -223,16 +231,21 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu ...@@ -223,16 +231,21 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu
using typename mav_info<ndim>::stride_t; using typename mav_info<ndim>::stride_t;
protected: protected:
using membuf<T>::Tsp; // using membuf<T>::Tsp;
using mav_info<ndim>::size;
using membuf<T>::d; using membuf<T>::d;
using membuf<T>::ptr; using membuf<T>::ptr;
using mav_info<ndim>::shp; using mav_info<ndim>::shp;
using mav_info<ndim>::str; using mav_info<ndim>::str;
using membuf<T>::rw; using membuf<T>::rw;
using membuf<T>::val; using membuf<T>::vraw;
public: public:
using membuf<T>::operator[];
using membuf<T>::vdata;
using membuf<T>::data;
using mav_info<ndim>::size;
using mav_info<ndim>::idx;
mav(const T *d_, const shape_t &shp_, const stride_t &str_) mav(const T *d_, const shape_t &shp_, const stride_t &str_)
: mav_info<ndim>(shp_, str_), membuf<T>(d_) {} : mav_info<ndim>(shp_, str_), membuf<T>(d_) {}
mav(T *d_, const shape_t &shp_, const stride_t &str_, bool rw_=false) mav(T *d_, const shape_t &shp_, const stride_t &str_, bool rw_=false)
...@@ -243,58 +256,44 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu ...@@ -243,58 +256,44 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu
: mav_info<ndim>(shp_), membuf<T>(d_, rw_) {} : mav_info<ndim>(shp_), membuf<T>(d_, rw_) {}
mav(const array<size_t,ndim> &shp_) mav(const array<size_t,ndim> &shp_)
: mav_info<ndim>(shp_), membuf<T>(size()) {} : mav_info<ndim>(shp_), membuf<T>(size()) {}
mav(const mav &other) = delete; mav(const mav &other) = default;
mav(mav &&other) = default; mav(mav &&other) = default;
operator fmav<T>() const operator fmav<T>() const
{ {
return fmav<T>(const_cast<T *>(d), ptr, {shp.begin(), shp.end()}, {str.begin(), str.end()}, rw); return fmav<T>(data(), {shp.begin(), shp.end()}, {str.begin(), str.end()});
} }
const T &operator()(size_t i) const operator fmav<T>()
{ {
static_assert(ndim==1, "ndim must be 1"); return fmav<T>(*this, {shp.begin(), shp.end()}, {str.begin(), str.end()});
return val(str[0]*i);
} }
const T &operator()(size_t i) const
{ return operator[](idx(i)); }
const T &operator()(size_t i, size_t j) const const T &operator()(size_t i, size_t j) const
{ { return operator[](idx(i,j)); }
static_assert(ndim==2, "ndim must be 2");
return val(str[0]*i + str[1]*j);
}
const T &operator()(size_t i, size_t j, size_t k) const const T &operator()(size_t i, size_t j, size_t k) const
{ { return operator[](idx(i,j,k)); }
static_assert(ndim==3, "ndim must be 3"); T &v(size_t i)
return val(str[0]*i + str[1]*j + str[2]*k); { return vraw(idx(i)); }
} T &v(size_t i, size_t j)
T &operator()(size_t i) { return vraw(idx(i,j)); }
{ T &v(size_t i, size_t j, size_t k)
static_assert(ndim==1, "ndim must be 1"); { return vraw(idx(i,j,k)); }
return val(str[0]*i);
}
T &operator()(size_t i, size_t j)
{
static_assert(ndim==2, "ndim must be 2");
return val(str[0]*i + str[1]*j);
}
T &operator()(size_t i, size_t j, size_t k)
{
static_assert(ndim==3, "ndim must be 3");
return val(str[0]*i + str[1]*j + str[2]*k);
}
void fill(const T &val) void fill(const T &val)
{ {
MR_assert(rw, "array is nor writable"); T *d2 = vdata();
// FIXME: special cases for contiguous arrays and/or zeroing? // FIXME: special cases for contiguous arrays and/or zeroing?
if (ndim==1) if (ndim==1)
for (size_t i=0; i<shp[0]; ++i) for (size_t i=0; i<shp[0]; ++i)
d[str[0]*i]=val; d2[str[0]*i]=val;
else if (ndim==2) else if (ndim==2)
for (size_t i=0; i<shp[0]; ++i) for (size_t i=0; i<shp[0]; ++i)
for (size_t j=0; j<shp[1]; ++j) for (size_t j=0; j<shp[1]; ++j)
d[str[0]*i + str[1]*j] = val; d2[str[0]*i + str[1]*j] = val;
else if (ndim==3) else if (ndim==3)
for (size_t i=0; i<shp[0]; ++i) for (size_t i=0; i<shp[0]; ++i)
for (size_t j=0; j<shp[1]; ++j) for (size_t j=0; j<shp[1]; ++j)
for (size_t k=0; k<shp[2]; ++k) for (size_t k=0; k<shp[2]; ++k)
d[str[0]*i + str[1]*j + str[2]*k] = val; d2[str[0]*i + str[1]*j + str[2]*k] = val;
} }
}; };
...@@ -302,6 +301,7 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu ...@@ -302,6 +301,7 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu
using detail_mav::shape_t; using detail_mav::shape_t;
using detail_mav::stride_t; using detail_mav::stride_t;
using detail_mav::fmav_info;
using detail_mav::fmav; using detail_mav::fmav;
using detail_mav::mav; using detail_mav::mav;
......
This diff is collapsed.
...@@ -40,13 +40,13 @@ template<typename T> py::array ms2dirty_general2(const py::array &uvw_, ...@@ -40,13 +40,13 @@ template<typename T> py::array ms2dirty_general2(const py::array &uvw_,
size_t nv, double epsilon, bool do_wstacking, size_t nthreads, size_t nv, double epsilon, bool do_wstacking, size_t nthreads,
size_t verbosity) size_t verbosity)
{ {
auto uvw = to_cmav<double,2>(uvw_); auto uvw = to_mav<double,2>(uvw_, false);
auto freq = to_cmav<double,1>(freq_); auto freq = to_mav<double,1>(freq_, false);
auto ms = to_cmav<complex<T>,2>(ms_); auto ms = to_mav<complex<T>,2>(ms_, false);
auto wgt = get_optional_const_Pyarr<T>(wgt_, {ms.shape(0),ms.shape(1)}); auto wgt = get_optional_const_Pyarr<T>(wgt_, {ms.shape(0),ms.shape(1)});
auto wgt2 = to_cmav<T,2>(wgt); auto wgt2 = to_mav<T,2>(wgt, false);
auto dirty = make_Pyarr<T>({npix_x,npix_y}); auto dirty = make_Pyarr<T>({npix_x,npix_y});
auto dirty2 = to_mav<T,2>(dirty); auto dirty2 = to_mav<T,2>(dirty, true);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
ms2dirty_general(uvw,freq,ms,wgt2,pixsize_x,pixsize_y,nu,nv,epsilon, ms2dirty_general(uvw,freq,ms,wgt2,pixsize_x,pixsize_y,nu,nv,epsilon,
...@@ -120,13 +120,13 @@ template<typename T> py::array dirty2ms_general2(const py::array &uvw_, ...@@ -120,13 +120,13 @@ template<typename T> py::array dirty2ms_general2(const py::array &uvw_,
double pixsize_x, double pixsize_y, size_t nu, size_t nv, double epsilon, double pixsize_x, double pixsize_y, size_t nu, size_t nv, double epsilon,
bool do_wstacking, size_t nthreads, size_t verbosity) bool do_wstacking, size_t nthreads, size_t verbosity)
{ {
auto uvw = to_cmav<double,2>(uvw_); auto uvw = to_mav<double,2>(uvw_, false);
auto freq = to_cmav<double,1>(freq_); auto freq = to_mav<double,1>(freq_, false);
auto dirty = to_cmav<T,2>(dirty_); auto dirty = to_mav<T,2>(dirty_, false);
auto wgt = get_optional_const_Pyarr<T>(wgt_, {uvw.shape(0),freq.shape(0)}); auto wgt = get_optional_const_Pyarr<T>(wgt_, {uvw.shape(0),freq.shape(0)});
auto wgt2 = to_cmav<T,2>(wgt);