Commit 76a024d9 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'vectorization' into 'master'

Vectorization

See merge request mtr/cxxbase!8
parents 4b36a29f 1da6d79c
Pipeline #74111 passed with stages
in 12 minutes and 11 seconds
......@@ -162,7 +162,7 @@ class wigner_d_risbo_openmp
wigner_d_risbo_openmp(size_t lmax, double ang)
: p(sin(ang/2)), q(cos(ang/2)), sqt(2*lmax+1),
d({lmax+1,2*lmax+1}), dd({lmax+1,2*lmax+1}), n(-1)
{ for (size_t m=0; m<sqt.size(); ++m) sqt[m] = sqrt(double(m)); }
{ for (size_t m=0; m<sqt.size(); ++m) sqt[m] = std::sqrt(double(m)); }
const mav<double,2> &recurse()
{
......
import pyinterpol_ng
import numpy as np
import pysharp
import time
import matplotlib.pyplot as plt
np.random.seed(48)
......@@ -39,11 +37,11 @@ def convolve(alm1, alm2, lmax):
return job.map2alm(map)[0]*np.sqrt(4*np.pi)
lmax=60
lmax=1024
kmax=13
ncomp=1
separate=False
nptg = 1000000
separate=True
nptg = 50000000
epsilon = 1e-4
ofactor = 1.5
nthreads = 0 # use as many threads as available
......@@ -61,57 +59,70 @@ blm = random_alm(lmax, kmax, ncomp)
t0=time.time()
# build interpolator object for slm and blm
foo = pyinterpol_ng.PyInterpolator(slm,blm,separate,lmax, kmax, epsilon=epsilon, ofactor=ofactor, nthreads=nthreads)
print("setup time: ",time.time()-t0)
print("support:",foo.support())
t1 = time.time()-t0
print("Convolving sky and beam with lmax=mmax={}, kmax={}".format(lmax,kmax))
print("Interpolation taking place with a maximum error of {}\n"
"and an oversampling factor of {}".format(epsilon, ofactor))
supp = foo.support()
print("(resulting in a kernel support size of {}x{})".format(supp,supp))
if ncomp == 1:
print("One component")
else:
print("{} components, which are {}coadded".format(ncomp, "not " if separate else ""))
print("\nDouble precision convolution/interpolation:")
print("preparation of interpolation grid: {}s".format(t1))
t0=time.time()
nth = lmax+1
nph = 2*lmax+1
# compute a convolved map at a fixed psi and compare it to a map convolved
# "by hand"
ptg = np.zeros((nth,nph,3))
ptg[:,:,0] = (np.pi*(0.5+np.arange(nth))/nth).reshape((-1,1))
ptg[:,:,1] = (2*np.pi*(0.5+np.arange(nph))/nph).reshape((1,-1))
ptg[:,:,2] = np.pi*0.2
t0=time.time()
# do the actual interpolation
bar=foo.interpol(ptg.reshape((-1,3))).reshape((nth,nph,ncomp2))
print("interpolation time: ", time.time()-t0)
plt.subplot(2,2,1)
plt.imshow(bar[:,:,0])
bar2 = np.zeros((nth,nph))
blmfull = np.zeros(slm.shape)+0j
blmfull[0:blm.shape[0],:] = blm
for ith in range(nth):
rbeamth=pyinterpol_ng.rotate_alm(blmfull[:,0], lmax, ptg[ith,0,2],ptg[ith,0,0],0)
for iph in range(nph):
rbeam=pyinterpol_ng.rotate_alm(rbeamth, lmax, 0, 0, ptg[ith,iph,1])
bar2[ith,iph] = convolve(slm[:,0], rbeam, lmax).real
plt.subplot(2,2,2)
plt.imshow(bar2)
plt.subplot(2,2,3)
plt.imshow(bar2-bar[:,:,0])
plt.show()
ptg=np.random.uniform(0.,1.,3*nptg).reshape(nptg,3)
ptg[:,0]*=np.pi
ptg[:,1]*=2*np.pi
ptg[:,2]*=2*np.pi
#foo = pyinterpol_ng.PyInterpolator(slm,blm,separate,lmax, kmax, epsilon=1e-6, nthreads=2)
t0=time.time()
bar=foo.interpol(ptg)
del foo
print("interpolation time: ", time.time()-t0)
print("Interpolating {} random angle triplets: {}s".format(nptg, time.time() -t0))
t0=time.time()
fake = np.random.uniform(0.,1., (ptg.shape[0],ncomp2))
foo2 = pyinterpol_ng.PyInterpolator(lmax, kmax, ncomp2, epsilon=epsilon, ofactor=ofactor, nthreads=nthreads)
t0=time.time()
foo2.deinterpol(ptg.reshape((-1,3)), fake)
print("deinterpolation time: ", time.time()-t0)
print("Adjoint interpolation: {}s".format(time.time() -t0))
t0=time.time()
bla=foo2.getSlm(blm)
print("getSlm time: ", time.time()-t0)
del foo2
print("Computing s_lm: {}s".format(time.time() -t0))
v1 = np.sum([myalmdot(slm[:,i], bla[:,i] , lmax, lmax, 0) for i in range(ncomp)])
v2 = np.sum([np.vdot(fake[:,i],bar[:,i]) for i in range(ncomp2)])
print(v1/v2-1.)
print("Adjointness error: {}".format(v1/v2-1.))
# build interpolator object for slm and blm
t0=time.time()
foo_f = pyinterpol_ng.PyInterpolator_f(slm.astype(np.complex64),blm.astype(np.complex64),separate,lmax, kmax, epsilon=epsilon, ofactor=ofactor, nthreads=nthreads)
print("\nSingle precision convolution/interpolation:")
print("preparation of interpolation grid: {}s".format(time.time()-t0))
ptgf = ptg.astype(np.float32)
del ptg
fake_f = fake.astype(np.float32)
del fake
t0=time.time()
bar_f=foo_f.interpol(ptgf)
del foo_f
print("Interpolating {} random angle triplets: {}s".format(nptg, time.time() -t0))
foo2_f = pyinterpol_ng.PyInterpolator_f(lmax, kmax, ncomp2, epsilon=epsilon, ofactor=ofactor, nthreads=nthreads)
t0=time.time()
foo2_f.deinterpol(ptgf.reshape((-1,3)), fake_f)
print("Adjoint interpolation: {}s".format(time.time() -t0))
t0=time.time()
bla_f=foo2_f.getSlm(blm.astype(np.complex64))
del foo2_f
print("Computing s_lm: {}s".format(time.time() -t0))
v1 = np.sum([myalmdot(slm[:,i], bla_f[:,i] , lmax, lmax, 0) for i in range(ncomp)])
v2 = np.sum([np.vdot(fake_f[:,i],bar_f[:,i]) for i in range(ncomp2)])
print("Adjointness error: {}".format(v1/v2-1.))
......@@ -6,6 +6,9 @@
#ifndef MRUTIL_INTERPOL_NG_H
#define MRUTIL_INTERPOL_NG_H
#define SIMD_INTERPOL
#define SPECIAL_CASING
#include <vector>
#include <complex>
#include <cmath>
......@@ -13,6 +16,7 @@
#include "mr_util/math/gl_integrator.h"
#include "mr_util/math/es_kernel.h"
#include "mr_util/infra/mav.h"
#include "mr_util/infra/simd.h"
#include "mr_util/sharp/sharp.h"
#include "mr_util/sharp/sharp_almhelpers.h"
#include "mr_util/sharp/sharp_geomhelpers.h"
......@@ -126,6 +130,9 @@ template<typename T> class Interpolator
size_t supp;
ES_Kernel kernel;
size_t ncomp;
#ifdef SIMD_INTERPOL
mav<native_simd<T>,4> scube;
#endif
mav<T,4> cube; // the data cube (theta, phi, 2*mbeam+1, TGC)
void correct(mav<T,2> &arr, int spin)
......@@ -145,23 +152,25 @@ template<typename T> class Interpolator
for (size_t j=0; j<nphi0; ++j)
tmp.v(ntheta0-1,j) = arr(ntheta0-1,j);
auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads);
for (auto &f:fct) f/=nphi0;
vector<T> k2(fct.size());
for (size_t i=0; i<fct.size(); ++i) k2[i] = T(fct[i]/nphi0);
fmav<T> ftmp(tmp);
fmav<T> ftmp0(tmp.template subarray<2>({0,0},{nphi0, nphi0}));
convolve_1d(ftmp0, ftmp, 0, fct, nthreads);
convolve_1d(ftmp0, ftmp, 0, k2, nthreads);
fmav<T> ftmp2(tmp.template subarray<2>({0,0},{ntheta, nphi0}));
fmav<T> farr(arr);
convolve_1d(ftmp2, farr, 1, fct, nthreads);
convolve_1d(ftmp2, farr, 1, k2, nthreads);
}
void decorrect(mav<T,2> &arr, int spin)
{
T sfct = (spin&1) ? -1 : 1;
mav<T,2> tmp({nphi,nphi0});
auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads);
for (auto &f:fct) f/=nphi0;
vector<T> k2(fct.size());
for (size_t i=0; i<fct.size(); ++i) k2[i] = T(fct[i]/nphi0);
fmav<T> farr(arr);
fmav<T> ftmp2(tmp.template subarray<2>({0,0},{ntheta, nphi0}));
convolve_1d(farr, ftmp2, 1, fct, nthreads);
convolve_1d(farr, ftmp2, 1, k2, nthreads);
// extend to second half
for (size_t i=1, i2=nphi-1; i+1<ntheta; ++i,--i2)
for (size_t j=0,j2=nphi0/2; j<nphi0; ++j,++j2)
......@@ -171,14 +180,14 @@ template<typename T> class Interpolator
}
fmav<T> ftmp(tmp);
fmav<T> ftmp0(tmp.template subarray<2>({0,0},{nphi0, nphi0}));
convolve_1d(ftmp, ftmp0, 0, fct, nthreads);
convolve_1d(ftmp, ftmp0, 0, k2, nthreads);
for (size_t j=0; j<nphi0; ++j)
arr.v(0,j) = 0.5*tmp(0,j);
arr.v(0,j) = T(0.5)*tmp(0,j);
for (size_t i=1; i+1<ntheta0; ++i)
for (size_t j=0; j<nphi0; ++j)
arr.v(i,j) = tmp(i,j);
for (size_t j=0; j<nphi0; ++j)
arr.v(ntheta0-1,j) = 0.5*tmp(ntheta0-1,j);
arr.v(ntheta0-1,j) = T(0.5)*tmp(ntheta0-1,j);
}
vector<size_t> getIdx(const mav<T,2> &ptg) const
......@@ -218,7 +227,12 @@ template<typename T> class Interpolator
supp(ES_Kernel::get_supp(epsilon, ofactor)),
kernel(supp, ofactor, nthreads),
ncomp(separate ? slm.size() : 1),
cube({ntheta+2*supp, nphi+2*supp, 2*kmax+1, ncomp})
#ifdef SIMD_INTERPOL
scube({ntheta+2*supp, nphi+2*supp, ncomp, (2*kmax+1+native_simd<T>::size()-1)/native_simd<T>::size()}),
cube(reinterpret_cast<T *>(scube.vdata()),{ntheta+2*supp, nphi+2*supp, ncomp, ((2*kmax+1+native_simd<T>::size()-1)/native_simd<T>::size())*native_simd<T>::size()},true)
#else
cube({ntheta+2*supp, nphi+2*supp, ncomp, 2*kmax+1})
#endif
{
MR_assert((ncomp==1)||(ncomp==3), "currently only 1 or 3 components allowed");
MR_assert(slm.size()==blm.size(), "inconsistent slm and blm vectors");
......@@ -237,7 +251,7 @@ template<typename T> class Interpolator
vector<T>lnorm(lmax+1);
for (size_t i=0; i<=lmax; ++i)
lnorm[i]=std::sqrt(4*pi/(2*i+1.));
lnorm[i]=T(std::sqrt(4*pi/(2*i+1.)));
for (size_t icomp=0; icomp<ncomp; ++icomp)
{
......@@ -245,15 +259,15 @@ template<typename T> class Interpolator
for (size_t l=m; l<=lmax; ++l)
{
if (separate)
a1(l,m) = slm[icomp](l,m)*blm[icomp](l,0).real()*T(lnorm[l]);
a1(l,m) = slm[icomp](l,m)*blm[icomp](l,0).real()*lnorm[l];
else
{
a1(l,m) = 0;
for (size_t j=0; j<slm.size(); ++j)
a1(l,m) += slm[j](l,m)*blm[j](l,0).real()*T(lnorm[l]);
a1(l,m) += slm[j](l,m)*blm[j](l,0).real()*lnorm[l];
}
}
auto m1 = cube.template subarray<2>({supp,supp,0,icomp},{ntheta,nphi,0,0});
auto m1 = cube.template subarray<2>({supp,supp,icomp,0},{ntheta,nphi,0,0});
sharp_alm2map(a1.Alms().data(), m1.vdata(), *ginfo, *ainfo, 0, nthreads);
correct(m1,0);
......@@ -268,7 +282,7 @@ template<typename T> class Interpolator
{
if (separate)
{
auto tmp = blm[icomp](l,k)*T(-2*lnorm[l]);
auto tmp = blm[icomp](l,k)*(-2*lnorm[l]);
a1(l,m) = slm[icomp](l,m)*tmp.real();
a2(l,m) = slm[icomp](l,m)*tmp.imag();
}
......@@ -277,15 +291,15 @@ template<typename T> class Interpolator
a1(l,m) = a2(l,m) = 0;
for (size_t j=0; j<slm.size(); ++j)
{
auto tmp = blm[j](l,k)*T(-2*lnorm[l]);
auto tmp = blm[j](l,k)*(-2*lnorm[l]);
a1(l,m) += slm[j](l,m)*tmp.real();
a2(l,m) += slm[j](l,m)*tmp.imag();
}
}
}
}
auto m1 = cube.template subarray<2>({supp,supp,2*k-1,icomp},{ntheta,nphi,0,0});
auto m2 = cube.template subarray<2>({supp,supp,2*k ,icomp},{ntheta,nphi,0,0});
auto m1 = cube.template subarray<2>({supp,supp,icomp,2*k-1},{ntheta,nphi,0,0});
auto m2 = cube.template subarray<2>({supp,supp,icomp,2*k },{ntheta,nphi,0,0});
sharp_alm2map_spin(k, a1.Alms().data(), a2.Alms().data(), m1.vdata(),
m2.vdata(), *ginfo, *ainfo, 0, nthreads);
correct(m1,k);
......@@ -296,23 +310,23 @@ template<typename T> class Interpolator
// fill border regions
for (size_t i=0; i<supp; ++i)
for (size_t j=0, j2=nphi/2; j<nphi; ++j,++j2)
for (size_t k=0; k<cube.shape(2); ++k)
for (size_t k=0; k<cube.shape(3); ++k)
{
T fct = (((k+1)/2)&1) ? -1 : 1;
if (j2>=nphi) j2-=nphi;
for (size_t l=0; l<cube.shape(3); ++l)
for (size_t l=0; l<cube.shape(2); ++l)
{
cube.v(supp-1-i,j2+supp,k,l) = fct*cube(supp+1+i,j+supp,k,l);
cube.v(supp+ntheta+i,j2+supp,k,l) = fct*cube(supp+ntheta-2-i,j+supp,k,l);
cube.v(supp-1-i,j2+supp,l,k) = fct*cube(supp+1+i,j+supp,l,k);
cube.v(supp+ntheta+i,j2+supp,l,k) = fct*cube(supp+ntheta-2-i,j+supp,l,k);
}
}
for (size_t i=0; i<ntheta+2*supp; ++i)
for (size_t j=0; j<supp; ++j)
for (size_t k=0; k<cube.shape(2); ++k)
for (size_t l=0; l<cube.shape(3); ++l)
for (size_t k=0; k<cube.shape(3); ++k)
for (size_t l=0; l<cube.shape(2); ++l)
{
cube.v(i,j,k,l) = cube(i,j+nphi,k,l);
cube.v(i,j+nphi+supp,k,l) = cube(i,j+supp,k,l);
cube.v(i,j,l,k) = cube(i,j+nphi,l,k);
cube.v(i,j+nphi+supp,l,k) = cube(i,j+supp,l,k);
}
}
......@@ -329,15 +343,72 @@ template<typename T> class Interpolator
supp(ES_Kernel::get_supp(epsilon, ofactor)),
kernel(supp, ofactor, nthreads),
ncomp(ncomp_),
cube({ntheta+2*supp, nphi+2*supp, 2*kmax+1, ncomp_})
#ifdef SIMD_INTERPOL
scube({ntheta+2*supp, nphi+2*supp, ncomp, (2*kmax+1+native_simd<T>::size()-1)/native_simd<T>::size()}),
cube(reinterpret_cast<T *>(scube.vdata()),{ntheta+2*supp, nphi+2*supp, ncomp, ((2*kmax+1+native_simd<T>::size()-1)/native_simd<T>::size())*native_simd<T>::size()},true)
#else
cube({ntheta+2*supp, nphi+2*supp, ncomp, 2*kmax+1})
#endif
{
MR_assert((ncomp==1)||(ncomp==3), "currently only 1 or 3 components allowed");
MR_assert((supp<=ntheta) && (supp<=nphi), "support too large!");
cube.apply([](T &v){v=0.;});
}
#ifdef SIMD_INTERPOL
template<size_t nv, size_t nc> void interpol_help0(const T * MRUTIL_RESTRICT wt,
const T * MRUTIL_RESTRICT wp, const native_simd<T> * MRUTIL_RESTRICT p, size_t d0, size_t d1, const native_simd<T> * MRUTIL_RESTRICT psiarr2, mav<T,2> &res, size_t idx) const
{
array<native_simd<T>,nc> vv;
for (auto &vvv:vv) vvv=0;
for (size_t j=0; j<supp; ++j, p+=d0)
{
auto p1=p;
T wtj = wt[j];
for (size_t k=0; k<supp; ++k, p1+=d1)
{
array<native_simd<T>,nc> tvv;
for (auto &vvv:tvv) vvv=0;
for (size_t l=0; l<nv; ++l)
for (size_t c=0; c<nc; ++c)
tvv[c] += p1[l+nv*c]*psiarr2[l];
for (size_t c=0; c<nc; ++c)
vv[c] += (wtj*wp[k])*tvv[c];
}
}
for (size_t c=0; c<nc; ++c)
res.v(idx,c) = reduce(vv[c], std::plus<>());
}
template<size_t nv, size_t nc> void deinterpol_help0(const T * MRUTIL_RESTRICT wt,
const T * MRUTIL_RESTRICT wp, native_simd<T> * MRUTIL_RESTRICT p, size_t d0, size_t d1, const native_simd<T> * MRUTIL_RESTRICT psiarr2, const mav<T,2> &data, size_t idx) const
{
array<native_simd<T>,nc> vv;
for (size_t i=0; i<nc; ++i) vv[i] = data(idx,i);
for (size_t j=0; j<supp; ++j, p+=d0)
{
auto p1=p;
T wtj = wt[j];
for (size_t k=0; k<supp; ++k, p1+=d1)
{
native_simd<T> wtjwpk = wtj*wp[k];
for (size_t l=0; l<nv; ++l)
{
auto tmp = wtjwpk*psiarr2[l];
for (size_t c=0; c<nc; ++c)
p1[l+nv*c] += vv[c]*tmp;
}
}
}
}
#endif
void interpol (const mav<T,2> &ptg, mav<T,2> &res) const
{
#ifdef SIMD_INTERPOL
constexpr size_t vl=native_simd<T>::size();
MR_assert(scube.stride(3)==1, "bad stride");
size_t nv=scube.shape(3);
#endif
MR_assert(!adjoint, "cannot be called in adjoint mode");
MR_assert(ptg.shape(0)==res.shape(0), "dimension mismatch");
MR_assert(ptg.shape(1)==3, "second dimension must have length 3");
......@@ -350,6 +421,10 @@ template<typename T> class Interpolator
{
vector<T> wt(supp), wp(supp);
vector<T> psiarr(2*kmax+1);
#ifdef SIMD_INTERPOL
vector<native_simd<T>> psiarr2((2*kmax+1+vl-1)/vl);
for (auto &v:psiarr2) v=0;
#endif
while (auto rng=sched.getNext()) for(auto ind=rng.lo; ind<rng.hi; ++ind)
{
size_t i=idx[ind];
......@@ -373,30 +448,139 @@ template<typename T> class Interpolator
cnpsi=cnpsi*cpsi - snpsi*spsi;
snpsi=tmp;
}
#ifdef SIMD_INTERPOL
memcpy(reinterpret_cast<T *>(psiarr2.data()), psiarr.data(),
(2*kmax+1)*sizeof(T));
#endif
if (ncomp==1)
{
#ifndef SIMD_INTERPOL
T vv=0;
for (size_t j=0; j<supp; ++j)
for (size_t k=0; k<supp; ++k)
for (size_t l=0; l<2*kmax+1; ++l)
vv += cube(i0+j,i1+k,l,0)*wt[j]*wp[k]*psiarr[l];
vv += cube(i0+j,i1+k,0,l)*wt[j]*wp[k]*psiarr[l];
res.v(i,0) = vv;
#else
const native_simd<T> *p=&scube(i0,i1,0,0);
ptrdiff_t d0 = scube.stride(0);
ptrdiff_t d1 = scube.stride(1);
switch (nv)
{
#ifdef SPECIAL_CASING
case 1:
interpol_help0<1,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 2:
interpol_help0<2,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 3:
interpol_help0<3,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 4:
interpol_help0<4,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 5:
interpol_help0<5,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 6:
interpol_help0<6,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 7:
interpol_help0<7,1>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
#endif
default:
{
native_simd<T> vv=0.;
for (size_t j=0; j<supp; ++j, p+=d0)
{
auto p1=p;
T wtj = wt[j];
for (size_t k=0; k<supp; ++k, p1+=d1)
{
native_simd<T> tvv=0;
for (size_t l=0; l<nv; ++l)
tvv += p1[l]*psiarr2[l];
vv += wtj*wp[k]*tvv;
}
}
res.v(i,0) = reduce(vv, std::plus<>());
}
}
#endif
}
else // ncomp==3
{
#ifndef SIMD_INTERPOL
T v0=0., v1=0., v2=0.;
for (size_t j=0; j<supp; ++j)
for (size_t k=0; k<supp; ++k)
for (size_t l=0; l<2*kmax+1; ++l)
{
auto tmp = wt[j]*wp[k]*psiarr[l];
v0 += cube(i0+j,i1+k,l,0)*tmp;
v1 += cube(i0+j,i1+k,l,1)*tmp;
v2 += cube(i0+j,i1+k,l,2)*tmp;
v0 += cube(i0+j,i1+k,0,l)*tmp;
v1 += cube(i0+j,i1+k,1,l)*tmp;
v2 += cube(i0+j,i1+k,2,l)*tmp;
}
res.v(i,0) = v0;
res.v(i,1) = v1;
res.v(i,2) = v2;
#else
const native_simd<T> *p=&scube(i0,i1,0,0);
ptrdiff_t d0 = scube.stride(0);
ptrdiff_t d1 = scube.stride(1);
switch (nv)
{
#ifdef SPECIAL_CASING
case 1:
interpol_help0<1,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 2:
interpol_help0<2,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 3:
interpol_help0<3,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 4:
interpol_help0<4,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 5:
interpol_help0<5,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 6:
interpol_help0<6,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
case 7:
interpol_help0<7,3>(wt.data(), wp.data(), p, d0, d1, psiarr2.data(), res, i);
break;
#endif
default:
{
ptrdiff_t d2 = scube.stride(2);
native_simd<T> v0=0., v1=0., v2=0.;
for (size_t j=0; j<supp; ++j, p+=d0)
{
auto p1=p;
T wtj = wt[j];
for (size_t k=0; k<supp; ++k, p1+=d1)
{
native_simd<T> wtjwpk = wtj*wp[k];
for (size_t l=0; l<nv; ++l)
{
auto tmp = wtjwpk*psiarr2[l];
v0 += p1[l]*tmp;
v1 += p1[l+d2]*tmp;
v2 += p1[l+2*d2]*tmp;
}
}
}
res.v(i,0) = reduce(v0, std::plus<>());
res.v(i,1) = reduce(v1, std::plus<>());
res.v(i,2) = reduce(v2, std::plus<>());
}
}
#endif
}
}
});
......@@ -407,6 +591,12 @@ template<typename T> class Interpolator
void deinterpol (const mav<T,2> &ptg, const mav<T,2> &data)
{
#ifdef SIMD_INTERPOL
constexpr size_t vl=native_simd<T>::size();
MR_assert(scube.stride(3)==1, "bad stride");
MR_assert(scube.stride(2)==ptrdiff_t(scube.shape(3)), "bad stride");
size_t nv=scube.shape(3);
#endif
MR_assert(adjoint, "can only be called in adjoint mode");
MR_assert(ptg.shape(0)==data.shape(0), "dimension mismatch");
MR_assert(ptg.shape(1)==3, "second dimension must have length 3");
......@@ -426,14 +616,18 @@ template<typename T> class Interpolator
size_t b_theta=99999999999999, b_phi=9999999999999999;
vector<T> wt(supp), wp(supp);
vector<T> psiarr(2*kmax+1);
#ifdef SIMD_INTERPOL
vector<native_simd<T>> psiarr2((2*kmax+1+vl-1)/vl);
for (auto &v:psiarr2) v=0;
#endif
while (auto rng=sched.getNext()) for(auto ind=rng.lo; ind<rng.hi; ++ind)
{
size_t i=idx[ind];
T f0=0.5*supp+ptg(i,0)*xdtheta;