Commit 8705186d authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more mavs

parent 33590924
......@@ -222,24 +222,10 @@ struct util // hack to avoid duplicate symbols
return res;
}
static MRUTIL_NOINLINE void sanity_check(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, bool inplace)
static void sanity_check_axes(size_t ndim, const shape_t &axes)
{
auto ndim = shape.size();
if (ndim<1) throw std::runtime_error("ndim must be >= 1");
if ((stride_in.size()!=ndim) || (stride_out.size()!=ndim))
throw std::runtime_error("stride dimension mismatch");
if (inplace && (stride_in!=stride_out))
throw std::runtime_error("stride mismatch");
}
static MRUTIL_NOINLINE void sanity_check(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, bool inplace,
const shape_t &axes)
{
sanity_check(shape, stride_in, stride_out, inplace);
auto ndim = shape.size();
shape_t tmp(ndim,0);
if (axes.empty()) throw std::invalid_argument("no axes specified");
for (auto ax : axes)
{
if (ax>=ndim) throw std::invalid_argument("bad axis number");
......@@ -247,12 +233,30 @@ struct util // hack to avoid duplicate symbols
}
}
// Validates the arguments of a transform whose input and output arrays have
// the same element type:
//  - `axes` must be a valid axis list for an array with a1.ndim() dimensions,
//  - a1 and a2 must be conformable,
//  - when operating in place, both arrays must use identical strides.
static MRUTIL_NOINLINE void sanity_check_onetype(const fmav_info &a1,
  const fmav_info &a2, bool inplace, const shape_t &axes)
  {
  sanity_check_axes(a1.ndim(), axes);
  MR_assert(a1.conformable(a2), "array sizes are not conformable");
  // Strides only need to agree when input and output share storage.
  if (!inplace) return;
  MR_assert(a1.stride()==a2.stride(), "stride mismatch");
  }
// Checks that a complex array description `ac` and a real array description
// `ar` are compatible for a real<->complex transform over `axes`.
//
// Both arrays must have the same number of dimensions. Every axis must have
// the same length in both arrays, except axes.back(), where the complex array
// must have ar.shape(i)/2+1 entries.
// `axes` itself is validated by sanity_check_axes() first, so axes.back() is
// only evaluated for a non-empty axis list.
static MRUTIL_NOINLINE void sanity_check_cr(const fmav_info &ac,
  const fmav_info &ar, const shape_t &axes)
  {
  sanity_check_axes(ac.ndim(), axes);
  MR_assert(ac.ndim()==ar.ndim(), "dimension mismatch");
  const auto halved_axis = axes.back();
  for (size_t i=0; i<ac.ndim(); ++i)
    {
    // The expected length is computed separately on purpose:
    // `x == cond ? a : b` parses as `(x == cond) ? a : b`, which never
    // compares the two lengths.
    const auto expected = (i==halved_axis) ? (ar.shape(i)/2+1) : ar.shape(i);
    MR_assert(ac.shape(i)==expected, "axis length mismatch");
    }
  }
// Single-axis variant: checks that a complex array description `ac` and a
// real array description `ar` are compatible for a real<->complex transform
// along `axis`.
//
// Throws std::invalid_argument if `axis` is not a valid axis of `ac`.
// Both arrays must have the same number of dimensions. Every axis must have
// the same length in both arrays, except `axis`, where the complex array must
// have ar.shape(axis)/2+1 entries.
static MRUTIL_NOINLINE void sanity_check_cr(const fmav_info &ac,
  const fmav_info &ar, const size_t axis)
  {
  if (axis>=ac.ndim()) throw std::invalid_argument("bad axis number");
  MR_assert(ac.ndim()==ar.ndim(), "dimension mismatch");
  for (size_t i=0; i<ac.ndim(); ++i)
    {
    // The expected length is computed separately on purpose:
    // `x == cond ? a : b` parses as `(x == cond) ? a : b`, which never
    // compares the two lengths.
    const auto expected = (i==axis) ? (ar.shape(i)/2+1) : ar.shape(i);
    MR_assert(ac.shape(i)==expected, "axis length mismatch");
    }
  }
#ifdef MRUTIL_NO_THREADING
......@@ -2571,7 +2575,7 @@ MRUTIL_NOINLINE void general_nd(const cfmav<T> &in, const fmav<T> &out,
[&](Scheduler &sched) {
constexpr auto vlen = VLEN<T0>::val;
auto storage = alloc_tmp<T,T0>(in.shape(), len);
const auto &tin(iax==0? in : cfmav<T>(out));
const auto &tin(iax==0? in : out);
multi_iter<vlen> it(tin, out, axes[iax], sched.num_threads(), sched.thread_num());
#ifndef POCKETFFT_NO_VECTORS
if (vlen>1)
......@@ -2814,177 +2818,132 @@ struct ExecR2R
}
};
template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in,
const stride_t &stride_out, const shape_t &axes, bool forward,
const std::complex<T> *data_in, std::complex<T> *data_out, T fct,
size_t nthreads=1)
template<typename T> void c2c(const cfmav<std::complex<T>> &in,
const fmav<std::complex<T>> &out, const shape_t &axes, bool forward,
T fct, size_t nthreads=1)
{
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
const cfmav<Cmplx<T>> ain(reinterpret_cast<const Cmplx<T> *>(data_in), shape, stride_in);
const fmav<Cmplx<T>> aout(reinterpret_cast<Cmplx<T> *>(data_out), shape, stride_out);
general_nd<pocketfft_c<T>>(ain, aout, axes, fct, nthreads, ExecC2C{forward});
util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
if (in.size()==0) return;
cfmav<Cmplx<T>> in2(reinterpret_cast<const Cmplx<T> *>(in.data()), in);
fmav<Cmplx<T>> out2(reinterpret_cast<Cmplx<T> *>(out.data()), out);
general_nd<pocketfft_c<T>>(in2, out2, axes, fct, nthreads, ExecC2C{forward});
}
template<typename T> void dct(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1)
template<typename T> void dct(const cfmav<T> &in, const fmav<T> &out,
const shape_t &axes, int type, T fct, bool ortho, size_t nthreads=1)
{
if ((type<1) || (type>4)) throw std::invalid_argument("invalid DCT type");
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
const cfmav<T> ain(data_in, shape, stride_in);
const fmav<T> aout(data_out, shape, stride_out);
util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
if (in.size()==0) return;
const ExecDcst exec{ortho, type, true};
if (type==1)
general_nd<T_dct1<T>>(ain, aout, axes, fct, nthreads, exec);
general_nd<T_dct1<T>>(in, out, axes, fct, nthreads, exec);
else if (type==4)
general_nd<T_dcst4<T>>(ain, aout, axes, fct, nthreads, exec);
general_nd<T_dcst4<T>>(in, out, axes, fct, nthreads, exec);
else
general_nd<T_dcst23<T>>(ain, aout, axes, fct, nthreads, exec);
general_nd<T_dcst23<T>>(in, out, axes, fct, nthreads, exec);
}
template<typename T> void dst(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1)
template<typename T> void dst(const cfmav<T> &in, const fmav<T> &out,
const shape_t &axes, int type, T fct, bool ortho, size_t nthreads=1)
{
if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type");
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
const cfmav<T> ain(data_in, shape, stride_in);
const fmav<T> aout(data_out, shape, stride_out);
util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
const ExecDcst exec{ortho, type, false};
if (type==1)
general_nd<T_dst1<T>>(ain, aout, axes, fct, nthreads, exec);
general_nd<T_dst1<T>>(in, out, axes, fct, nthreads, exec);
else if (type==4)
general_nd<T_dcst4<T>>(ain, aout, axes, fct, nthreads, exec);
general_nd<T_dcst4<T>>(in, out, axes, fct, nthreads, exec);
else
general_nd<T_dcst23<T>>(ain, aout, axes, fct, nthreads, exec);
general_nd<T_dcst23<T>>(in, out, axes, fct, nthreads, exec);
}
template<typename T> void r2c(const shape_t &shape_in,
const stride_t &stride_in, const stride_t &stride_out, size_t axis,
bool forward, const T *data_in, std::complex<T> *data_out, T fct,
template<typename T> void r2c(const cfmav<T> &in,
const fmav<std::complex<T>> &out, size_t axis, bool forward, T fct,
size_t nthreads=1)
{
if (util::prod(shape_in)==0) return;
util::sanity_check(shape_in, stride_in, stride_out, false, axis);
const cfmav<T> ain(data_in, shape_in, stride_in);
shape_t shape_out(shape_in);
shape_out[axis] = shape_in[axis]/2 + 1;
const fmav<Cmplx<T>> aout(reinterpret_cast<Cmplx<T> *>(data_out), shape_out, stride_out);
general_r2c(ain, aout, axis, forward, fct, nthreads);
util::sanity_check_cr(out, in, axis);
if (in.size()==0) return;
fmav<Cmplx<T>> out2(reinterpret_cast<Cmplx<T> *>(out.data()), out);
general_r2c(in, out2, axis, forward, fct, nthreads);
}
template<typename T> void r2c(const shape_t &shape_in,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
bool forward, const T *data_in, std::complex<T> *data_out, T fct,
size_t nthreads=1)
template<typename T> void r2c(const cfmav<T> &in,
const fmav<std::complex<T>> &out, const shape_t &axes,
bool forward, T fct, size_t nthreads=1)
{
if (util::prod(shape_in)==0) return;
util::sanity_check(shape_in, stride_in, stride_out, false, axes);
r2c(shape_in, stride_in, stride_out, axes.back(), forward, data_in, data_out,
fct, nthreads);
util::sanity_check_cr(out, in, axes);
if (in.size()==0) return;
r2c(in, out, axes.back(), forward, fct, nthreads);
if (axes.size()==1) return;
shape_t shape_out(shape_in);
shape_out[axes.back()] = shape_in[axes.back()]/2 + 1;
auto newaxes = shape_t{axes.begin(), --axes.end()};
c2c(shape_out, stride_out, stride_out, newaxes, forward, data_out, data_out,
T(1), nthreads);
c2c(out, out, newaxes, forward, T(1), nthreads);
}
template<typename T> void c2r(const shape_t &shape_out,
const stride_t &stride_in, const stride_t &stride_out, size_t axis,
bool forward, const std::complex<T> *data_in, T *data_out, T fct,
size_t nthreads=1)
template<typename T> void c2r(const cfmav<std::complex<T>> &in,
const fmav<T> &out, size_t axis, bool forward, T fct, size_t nthreads=1)
{
if (util::prod(shape_out)==0) return;
util::sanity_check(shape_out, stride_in, stride_out, false, axis);
shape_t shape_in(shape_out);
shape_in[axis] = shape_out[axis]/2 + 1;
const cfmav<Cmplx<T>> ain(reinterpret_cast<const Cmplx<T> *>(data_in), shape_in, stride_in);
const fmav<T> aout(data_out, shape_out, stride_out);
general_c2r(ain, aout, axis, forward, fct, nthreads);
util::sanity_check_cr(in, out, axis);
if (in.size()==0) return;
cfmav<Cmplx<T>> in2(reinterpret_cast<const Cmplx<T> *>(in.data()), in);
general_c2r(in2, out, axis, forward, fct, nthreads);
}
template<typename T> void c2r(const shape_t &shape_out,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
bool forward, const std::complex<T> *data_in, T *data_out, T fct,
template<typename T> void c2r(const cfmav<std::complex<T>> &in,
const fmav<T> &out, const shape_t &axes, bool forward, T fct,
size_t nthreads=1)
{
if (util::prod(shape_out)==0) return;
if (axes.size()==1)
return c2r(shape_out, stride_in, stride_out, axes[0], forward,
data_in, data_out, fct, nthreads);
util::sanity_check(shape_out, stride_in, stride_out, false, axes);
auto shape_in = shape_out;
shape_in[axes.back()] = shape_out[axes.back()]/2 + 1;
auto nval = util::prod(shape_in);
stride_t stride_inter(shape_in.size());
stride_inter.back() = 1;
for (int i=int(shape_in.size())-2; i>=0; --i)
stride_inter[size_t(i)] =
stride_inter[size_t(i+1)]*ptrdiff_t(shape_in[size_t(i+1)]);
aligned_array<std::complex<T>> tmp(nval);
return c2r(in, out, axes[0], forward, fct, nthreads);
util::sanity_check_cr(in, out, axes);
if (in.size()==0) return;
aligned_array<std::complex<T>> tmp(in.size());
fmav<std::complex<T>> atmp(tmp.data(), in);
auto newaxes = shape_t({axes.begin(), --axes.end()});
c2c(shape_in, stride_in, stride_inter, newaxes, forward, data_in, tmp.data(),
T(1), nthreads);
c2r(shape_out, stride_inter, stride_out, axes.back(), forward,
tmp.data(), data_out, fct, nthreads);
c2c(in, atmp, newaxes, forward, T(1), nthreads);
c2r(atmp, out, axes.back(), forward, fct, nthreads);
}
template<typename T> void r2r_fftpack(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
bool real2hermitian, bool forward, const T *data_in, T *data_out, T fct,
size_t nthreads=1)
template<typename T> void r2r_fftpack(const cfmav<T> &in,
const fmav<T> &out, const shape_t &axes, bool real2hermitian, bool forward,
T fct, size_t nthreads=1)
{
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
const cfmav<T> ain(data_in, shape, stride_in);
const fmav<T> aout(data_out, shape, stride_out);
general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads,
util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
if (in.size()==0) return;
general_nd<pocketfft_r<T>>(in, out, axes, fct, nthreads,
ExecR2R{real2hermitian, forward});
}
template<typename T> void r2r_separable_hartley(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
const T *data_in, T *data_out, T fct, size_t nthreads=1)
template<typename T> void r2r_separable_hartley(const cfmav<T> &in,
const fmav<T> &out, const shape_t &axes, T fct, size_t nthreads=1)
{
if (util::prod(shape)==0) return;
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
const cfmav<T> ain(data_in, shape, stride_in);
const fmav<T> aout(data_out, shape, stride_out);
general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads, ExecHartley{},
util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
if (in.size()==0) return;
general_nd<pocketfft_r<T>>(in, out, axes, fct, nthreads, ExecHartley{},
false);
}
template<typename T> void r2r_genuine_hartley(const shape_t &shape,
const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
const T *data_in, T *data_out, T fct, size_t nthreads=1)
template<typename T> void r2r_genuine_hartley(const cfmav<T> &in,
const fmav<T> &out, const shape_t &axes, T fct, size_t nthreads=1)
{
if (util::prod(shape)==0) return;
if (axes.size()==1)
return r2r_separable_hartley(shape, stride_in, stride_out, axes, data_in,
data_out, fct, nthreads);
util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
shape_t tshp(shape);
return r2r_separable_hartley(in, out, axes, fct, nthreads);
util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
if (in.size()==0) return;
shape_t tshp(in.shape());
tshp[axes.back()] = tshp[axes.back()]/2+1;
aligned_array<std::complex<T>> tdata(util::prod(tshp));
stride_t tstride(shape.size());
tstride.back()=1;
for (size_t i=tstride.size()-1; i>0; --i)
tstride[i-1]=tstride[i]*ptrdiff_t(tshp[i]);
r2c(shape, stride_in, tstride, axes, true, data_in, tdata.data(), fct, nthreads);
const cfmav<Cmplx<T>> atmp(reinterpret_cast<const Cmplx<T> *>(tdata.data()), tshp, tstride);
const fmav<T> aout(data_out, shape, stride_out);
auto tinfo = fmav_info(tshp);
aligned_array<std::complex<T>> tdata(tinfo.size());
fmav<std::complex<T>> atmp(tdata.data(), tinfo);
r2c(in, atmp, axes, true, fct, nthreads);
simple_iter iin(atmp);
rev_iter iout(aout, axes);
rev_iter iout(out, axes);
while(iin.remaining()>0)
{
auto v = atmp[iin.ofs()];
aout[iout.ofs()] = v.r+v.i;
aout[iout.rev_ofs()] = v.r-v.i;
out[iout.ofs()] = v.real()+v.imag();
out[iout.rev_ofs()] = v.real()-v.imag();
iin.advance(); iout.advance();
}
}
......
......@@ -77,7 +77,7 @@ class fmav_info
for (size_t i=0; i<ndim; ++i)
{
if (str[ndim-1-i]!=stride) return false;
stride *= shp[ndim-1-i];
stride *= ptrdiff_t(shp[ndim-1-i]);
}
return true;
}
......@@ -96,6 +96,8 @@ template<typename T> class cfmav: public fmav_info
: fmav_info(shp_, str_), d(const_cast<T *>(d_)) {}
cfmav(const T *d_, const shape_t &shp_)
: fmav_info(shp_), d(const_cast<T *>(d_)) {}
// Builds a view on the data at d_ whose shape/stride description is copied
// from an existing fmav_info object. The pointer is stored as T* internally
// (hence the const_cast).
cfmav(const T* d_, const fmav_info &info)
: fmav_info(info), d(const_cast<T *>(d_)) {}
template<typename I> const T &operator[](I i) const
{ return d[i]; }
const T *data() const
......@@ -114,6 +116,8 @@ template<typename T> class fmav: public cfmav<T>
: parent(d_, shp_, str_) {}
fmav(T *d_, const shape_t &shp_)
: parent(d_, shp_) {}
// Builds a view on the data at d_, forwarding the fmav_info description
// to the parent class constructor.
fmav(T* d_, const fmav_info &info)
: parent(d_, info) {}
template<typename I> T &operator[](I i) const
{ return d[i]; }
using parent::shape;
......@@ -125,36 +129,40 @@ template<typename T> class fmav: public cfmav<T>
using parent::conformable;
};
template<typename T> using const_fmav = fmav<const T>;
template<typename T, size_t ndim> class mav
template<typename T, size_t ndim> class cmav
{
static_assert((ndim>0) && (ndim<3), "only supports 1D and 2D arrays");
private:
protected:
T *d;
array<size_t, ndim> shp;
array<ptrdiff_t, ndim> str;
public:
mav(T *d_, const array<size_t,ndim> &shp_,
cmav(const T *d_, const array<size_t,ndim> &shp_,
const array<ptrdiff_t,ndim> &str_)
: d(d_), shp(shp_), str(str_) {}
mav(T *d_, const array<size_t,ndim> &shp_)
: d(d_), shp(shp_)
: d(const_cast<T *>(d_)), shp(shp_), str(str_) {}
cmav(const T *d_, const array<size_t,ndim> &shp_)
: d(const_cast<T *>(d_)), shp(shp_)
{
str[ndim-1]=1;
for (size_t i=2; i<=ndim; ++i)
str[ndim-i] = str[ndim-i+1]*shp[ndim-i+1];
}
T &operator[](size_t i) const
// Conversion to a cfmav<T> over the same data pointer; the fixed-size shape
// and stride arrays are copied into the cfmav's shape/stride arguments.
operator cfmav<T>() const
{
return cfmav<T>(d, {shp.begin(), shp.end()}, {str.begin(), str.end()});
}
const T &operator[](size_t i) const
{ return operator()(i); }
T &operator()(size_t i) const
const T &operator()(size_t i) const
{
static_assert(ndim==1, "ndim must be 1");
return d[str[0]*i];
}
T &operator()(size_t i, size_t j) const
const T &operator()(size_t i, size_t j) const
{
static_assert(ndim==2, "ndim must be 2");
return d[str[0]*i + str[1]*j];
......@@ -168,7 +176,7 @@ template<typename T, size_t ndim> class mav
return res;
}
ptrdiff_t stride(size_t i) const { return str[i]; }
T *data() const
const T *data() const
{ return d; }
bool last_contiguous() const
{ return (str[ndim-1]==1) || (str[ndim-1]==0); }
......@@ -182,6 +190,46 @@ template<typename T, size_t ndim> class mav
}
return true;
}
// Returns true if `other` has exactly the same shape as this array.
// The comparison goes through the public shape() accessor because
// cmav<T2,ndim> is a different class whenever T2 differs from T, so its
// protected members are not accessible from here.
template<typename T2> bool conformable(const cmav<T2,ndim> &other) const
  { return shp==other.shape(); }
};
template<typename T, size_t ndim> class mav: public cmav<T, ndim>
{
protected:
using parent = cmav<T, ndim>;
using parent::d;
using parent::shp;
using parent::str;
public:
// Builds a view on d_ with explicitly given shape and strides
// (forwarded to the cmav parent constructor).
mav(T *d_, const array<size_t,ndim> &shp_,
const array<ptrdiff_t,ndim> &str_)
: parent(d_, shp_, str_) {}
// Builds a view on d_ with the given shape; strides are chosen by the
// shape-only cmav parent constructor.
mav(T *d_, const array<size_t,ndim> &shp_)
: parent(d_, shp_) {}
// Conversion to an fmav<T> over the same data pointer; the fixed-size shape
// and stride arrays are copied into the fmav's shape/stride arguments.
operator fmav<T>() const
{
return fmav<T>(d, {shp.begin(), shp.end()}, {str.begin(), str.end()});
}
// Mutable element access; forwards to the one-index operator(), which is
// restricted to ndim==1 by a static_assert.
T &operator[](size_t i) const
{ return operator()(i); }
// Mutable access to element i of a 1D array (compile-time error otherwise).
// NOTE(review): str[0] is ptrdiff_t and i is size_t, so the offset is
// computed in mixed signed/unsigned arithmetic - confirm this is intended
// for negative strides.
T &operator()(size_t i) const
{
static_assert(ndim==1, "ndim must be 1");
return d[str[0]*i];
}
// Mutable access to element (i,j) of a 2D array (compile-time error
// otherwise); the offset is the sum of the per-axis stride*index products.
T &operator()(size_t i, size_t j) const
{
static_assert(ndim==2, "ndim must be 2");
return d[str[0]*i + str[1]*j];
}
using parent::shape;
using parent::stride;
using parent::size;
// Returns the stored data pointer as a pointer to non-const T.
T *data() const
{ return d; }
using parent::last_contiguous;
using parent::contiguous;
void fill(const T &val) const
{
// FIXME: special cases for contiguous arrays and/or zeroing?
......@@ -193,20 +241,9 @@ template<typename T, size_t ndim> class mav
for (size_t j=0; j<shp[1]; ++j)
d[str[0]*i + str[1]*j] = val;
}
template<typename T2> bool conformable(const mav<T2,ndim> &other) const
{ return shp==other.shp; }
using parent::conformable;
};
template<typename T, size_t ndim> using const_mav = mav<const T, ndim>;
template<typename T, size_t ndim> const_mav<T, ndim> cmav (const mav<T, ndim> &mav)
{ return const_mav<T, ndim>(mav.data(), mav.shape()); }
template<typename T, size_t ndim> const_mav<T, ndim> nullmav()
{
array<size_t,ndim> shp;
shp.fill(0);
return const_mav<T, ndim>(nullptr, shp);
}
}
using detail_mav::shape_t;
......@@ -214,11 +251,8 @@ using detail_mav::stride_t;
using detail_mav::fmav_info;
using detail_mav::fmav;
using detail_mav::cfmav;
using detail_mav::const_fmav;
using detail_mav::mav;
using detail_mav::const_mav;
using detail_mav::cmav;
using detail_mav::nullmav;
}
......
......@@ -76,7 +76,6 @@ template<typename T, size_t ndim> class tmpStorage
tmpStorage(const array<size_t,ndim> &shp)
: d(prod(shp)), mav_(d.data(), shp) {}
mav<T,ndim> &getMav() { return mav_; }
const_mav<T,ndim> getCmav() { return cmav(mav_); }
void fill(const T & val)
{ std::fill(d.begin(), d.end(), val); }
};
......@@ -86,7 +85,7 @@ template<typename T, size_t ndim> class tmpStorage
//
template<typename T> void complex2hartley
(const const_mav<complex<T>, 2> &grid, const mav<T,2> &grid2, size_t nthreads)
(const cmav<complex<T>, 2> &grid, const mav<T,2> &grid2, size_t nthreads)
{
checkShape(grid.shape(), grid2.shape());
size_t nu=grid.shape(0), nv=grid.shape(1);
......@@ -107,7 +106,7 @@ template<typename T> void complex2hartley
}
template<typename T> void hartley2complex
(const const_mav<T,2> &grid, const mav<complex<T>,2> &grid2, size_t nthreads)
(const cmav<T,2> &grid, const mav<complex<T>,2> &grid2, size_t nthreads)
{
checkShape(grid.shape(), grid2.shape());
size_t nu=grid.shape(0), nv=grid.shape(1);
......@@ -128,17 +127,13 @@ template<typename T> void hartley2complex
});
}
template<typename T> void hartley2_2D(const const_mav<T,2> &in,
template<typename T> void hartley2_2D(const cmav<T,2> &in,
const mav<T,2> &out, size_t nthreads)
{
checkShape(in.shape(), out.shape());
size_t nu=in.shape(0), nv=in.shape(1);
stride_t stri{in.stride(0), in.stride(1)};
stride_t stro{out.stride(0), out.stride(1)};
auto d_i = in.data();
auto ptmp = out.data();
r2r_separable_hartley({nu, nv}, stri, stro, {0,1}, d_i, ptmp, T(1),
nthreads);
r2r_separable_hartley(cfmav<T>(in), fmav<T>(out), {0,1}, T(1), nthreads);
execStatic((nu+1)/2-1, nthreads, 0, [&](Scheduler &sched)
{
while (auto rng=sched.getNext()) for(auto i=rng.lo+1; i<rng.hi+1; ++i)
......@@ -285,8 +280,8 @@ class Baselines
idx_t shift, mask;
public:
template<typename T> Baselines(const const_mav<T,2> &coord_,
const const_mav<T,1> &freq, bool negate_v=false)
template<typename T> Baselines(const cmav<T,2> &coord_,
const cmav<T,1> &freq, bool negate_v=false)
{
constexpr double speedOfLight = 299792458.;
MR_assert(coord_.shape(1)==3, "dimension mismatch");
......@@ -457,7 +452,7 @@ class GridderConfig
});
}
template<typename T> void grid2dirty(const const_mav<T,2> &grid,
template<typename T> void grid2dirty(const cmav<T,2> &grid,
const mav<T,2> &dirty) const
{
checkShape(grid.shape(), {nu,nv});
......@@ -471,13 +466,12 @@ class GridderConfig
(const mav<complex<T>,2> &grid, const mav<T,2> &dirty, T w) const
{
checkShape(grid.shape(), {nu,nv});
c2c({nu,nv},{grid.stride(0),grid.stride(1)},
{grid.stride(0), grid.stride(1)}, {0,1}, BACKWARD,
grid.data(), grid.data(), T(1), nthreads);
fmav<complex<T>> inout(grid);
c2c(inout, inout, {0,1}, BACKWARD, T(1), nthreads);
grid2dirty_post2(grid, dirty, w);
}
template<typename T> void dirty2grid_pre(const const_mav<T,2> &dirty,
template<typename T> void dirty2grid_pre(const cmav<T,2> &dirty,
const mav<T,2> &grid) const
{
checkShape(dirty.shape(), {nx_dirty, ny_dirty});
......@@ -502,7 +496,7 @@ class GridderConfig
}
});
}
template<typename T> void dirty2grid_pre2(const const_mav<T,2> &dirty,
template<typename T> void dirty2grid_pre2(const cmav<T,2> &d