Commit 07c29f71 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

faster convolutions

parent 54db3189
......@@ -22,7 +22,6 @@
namespace mr {
#if 0
namespace detail_fft {
using std::vector;
......@@ -39,7 +38,7 @@ template<typename T, typename T0> aligned_array<T> alloc_tmp_conv
template<typename Tplan, typename T, typename T0, typename Exec>
MRUTIL_NOINLINE void general_convolve(const fmav<T> &in, fmav<T> &out,
const size_t axis, const vector<T0> &kernel, size_t nthreads,
const Exec &exec, const bool allow_inplace=true)
const Exec &exec)
{
std::shared_ptr<Tplan> plan1, plan2;
......@@ -53,7 +52,7 @@ MRUTIL_NOINLINE void general_convolve(const fmav<T> &in, fmav<T> &out,
util::thread_count(nthreads, in, axis, native_simd<T0>::size()),
[&](Scheduler &sched) {
constexpr auto vlen = native_simd<T0>::size();
auto storage = alloc_tmp_conv<T,T0>(in, axis, l_max); //FIXME!
auto storage = alloc_tmp_conv<T,T0>(in, axis, l_max);
multi_iter<vlen> it(in, out, axis, sched.num_threads(), sched.thread_num());
#ifndef MRUTIL_NO_SIMD
if (vlen>1)
......@@ -67,8 +66,7 @@ MRUTIL_NOINLINE void general_convolve(const fmav<T> &in, fmav<T> &out,
while (it.remaining()>0)
{
it.advance(1);
auto buf = allow_inplace && it.stride_out() == 1 ?
&out.vraw(it.oofs(0)) : reinterpret_cast<T *>(storage.data());
auto buf = reinterpret_cast<T *>(storage.data());
exec(it, in, out, buf, *plan1, *plan2, kernel);
}
}); // end of parallel region
......@@ -86,9 +84,7 @@ struct ExecConvR1
l_min = std::min(l_in, l_out);
copy_input(it, in, buf);
plan1.exec(buf, T0(1), true);
buf[0] *= kernel[0];
for (size_t i=1; i<l_min; ++i)
{ buf[2*i-1]*=kernel[i]; buf[2*i] *=kernel[i]; }
for (size_t i=0; i<l_min; ++i) buf[i]*=kernel[(i+1)/2];
for (size_t i=l_in; i<l_out; ++i) buf[i] = T(0);
plan2.exec(buf, T0(1), false);
copy_output(it, buf, out);
......@@ -98,21 +94,24 @@ struct ExecConvR1
template<typename T> void convolve_1d(const fmav<T> &in,
fmav<T> &out, size_t axis, const vector<T> &kernel, size_t nthreads=1)
{
// util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
MR_assert(axis<in.ndim(), "bad axis number");
MR_assert(in.ndim()==out.ndim(), "dimensionality mismatch");
if (in.data()==out.data())
MR_assert(in.strides()==out.strides(), "strides mismatch");
MR_assert(in.stride()==out.stride(), "strides mismatch");
for (size_t i=0; i<in.ndim(); ++i)
if (i!=axis)
MR_assert(in.shape(i)==out.shape(i), "shape mismatch");
MR_assert(!((in.shape(axis)&1) || (out.shape(axis)&1)),
"input and output axis lengths must be even");
if (in.size()==0) return;
general_convolve<pocketfft_r<T>>(in, out, axis, kernel, nthreads,
ExecConvR1());
}
}
#endif
using detail_fft::convolve_1d;
namespace detail_interpol_ng {
using namespace std;
......@@ -132,83 +131,54 @@ template<typename T> class Interpolator
void correct(mav<T,2> &arr, int spin)
{
T sfct = (spin&1) ? -1 : 1;
mav<T,2> tmp({nphi,nphi});
tmp.apply([](T &v){v=0.;});
auto tmp0=tmp.template subarray<2>({0,0},{nphi0, nphi0});
fmav<T> ftmp0(tmp0);
for (size_t i=0; i<ntheta0; ++i)
for (size_t j=0; j<nphi0; ++j)
tmp0.v(i,j) = arr(i,j);
// extend to second half
// FIXME: merge with loop above to avoid edundant memory reads.
mav<T,2> tmp({nphi,nphi0});
// copy and extend to second half
for (size_t j=0; j<nphi0; ++j)
tmp.v(0,j) = arr(0,j);
for (size_t i=1, i2=nphi0-1; i+1<ntheta0; ++i,--i2)
for (size_t j=0,j2=nphi0/2; j<nphi0; ++j,++j2)
{
if (j2>=nphi0) j2-=nphi0;
tmp0.v(i2,j) = sfct*tmp0(i,j2);
tmp.v(i,j2) = arr(i,j2);
tmp.v(i2,j) = sfct*tmp(i,j2);
}
// FFT to frequency domain on minimal grid
// one bad FFT axis
r2r_fftpack(ftmp0,ftmp0,{0,1},true,true,T(1./(nphi0*nphi0)),nthreads);
// correct amplitude at Nyquist frequency
for (size_t i=0; i<nphi0; ++i)
{
tmp0.v(i,nphi0-1)*=0.5;
tmp0.v(nphi0-1,i)*=0.5;
}
for (size_t j=0; j<nphi0; ++j)
tmp.v(ntheta0-1,j) = arr(ntheta0-1,j);
auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads);
for (size_t i=0; i<nphi0; ++i)
for (size_t j=0; j<nphi0; ++j)
tmp0.v(i,j) *= fct[(i+1)/2] * fct[(j+1)/2];
auto tmp1=tmp.template subarray<2>({0,0},{nphi, nphi0});
fmav<T> ftmp1(tmp1);
// zero-padded FFT in theta direction
// one bad FFT axis
r2r_fftpack(ftmp1,ftmp1,{0},false,false,T(1),nthreads);
auto tmp2=tmp.template subarray<2>({0,0},{ntheta, nphi});
fmav<T> ftmp2(tmp2);
for (auto &f:fct) f/=nphi0;
fmav<T> ftmp(tmp);
fmav<T> ftmp0(tmp.template subarray<2>({0,0},{nphi0, nphi0}));
convolve_1d(ftmp0, ftmp, 0, fct, nthreads);
fmav<T> ftmp2(tmp.template subarray<2>({0,0},{ntheta, nphi0}));
fmav<T> farr(arr);
// zero-padded FFT in phi direction
r2r_fftpack(ftmp2,farr,{1},false,false,T(1),nthreads);
convolve_1d(ftmp2, farr, 1, fct, nthreads);
}
void decorrect(mav<T,2> &arr, int spin)
{
T sfct = (spin&1) ? -1 : 1;
mav<T,2> tmp({nphi,nphi});
fmav<T> ftmp(tmp);
for (size_t i=0; i<ntheta; ++i)
for (size_t j=0; j<nphi; ++j)
tmp.v(i,j) = arr(i,j);
mav<T,2> tmp({nphi,nphi0});
auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads);
for (auto &f:fct) f/=nphi0;
fmav<T> farr(arr);
fmav<T> ftmp2(tmp.template subarray<2>({0,0},{ntheta, nphi0}));
convolve_1d(farr, ftmp2, 1, fct, nthreads);
// extend to second half
for (size_t i=1, i2=nphi-1; i+1<ntheta; ++i,--i2)
for (size_t j=0,j2=nphi/2; j<nphi; ++j,++j2)
for (size_t j=0,j2=nphi0/2; j<nphi0; ++j,++j2)
{
if (j2>=nphi) j2-=nphi;
if (j2>=nphi0) j2-=nphi0;
tmp.v(i2,j) = sfct*tmp(i,j2);
}
r2r_fftpack(ftmp,ftmp,{1},true,true,T(1),nthreads);
auto tmp1=tmp.template subarray<2>({0,0},{nphi, nphi0});
fmav<T> ftmp1(tmp1);
r2r_fftpack(ftmp1,ftmp1,{0},true,true,T(1),nthreads);
auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads);
auto tmp0=tmp.template subarray<2>({0,0},{nphi0, nphi0});
fmav<T> ftmp0(tmp0);
for (size_t i=0; i<nphi0; ++i)
for (size_t j=0; j<nphi0; ++j)
tmp0.v(i,j) *= fct[(i+1)/2] * fct[(j+1)/2];
// FFT to (theta, phi) domain on minimal grid
r2r_fftpack(ftmp0,ftmp0,{0,1},false, false,T(1./(nphi0*nphi0)),nthreads);
fmav<T> ftmp(tmp);
fmav<T> ftmp0(tmp.template subarray<2>({0,0},{nphi0, nphi0}));
convolve_1d(ftmp, ftmp0, 0, fct, nthreads);
for (size_t j=0; j<nphi0; ++j)
{
tmp0.v(0,j)*=0.5;
tmp0.v(ntheta0-1,j)*=0.5;
}
for (size_t i=0; i<ntheta0; ++i)
arr.v(0,j) = 0.5*tmp(0,j);
for (size_t i=1; i+1<ntheta0; ++i)
for (size_t j=0; j<nphi0; ++j)
arr.v(i,j) = tmp0(i,j);
arr.v(i,j) = tmp(i,j);
for (size_t j=0; j<nphi0; ++j)
arr.v(ntheta0-1,j) = 0.5*tmp(ntheta0-1,j);
}
vector<size_t> getIdx(const mav<T,2> &ptg) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment