Commit af6ffa43 authored by Martin Reinecke
Browse files

add vectorized interpolation and more

parent ed1b7cf7
Pipeline #73819 passed with stages
in 12 minutes and 8 seconds
......@@ -162,7 +162,7 @@ class wigner_d_risbo_openmp
// Constructor. Stores p = sin(ang/2) and q = cos(ang/2), initialises sqt with
// 2*lmax+1 and d / dd with {lmax+1, 2*lmax+1}, sets n to -1, and then fills
// sqt[m] with the square root of m for every index m of sqt.
// NOTE(review): the two brace-enclosed body lines below differ only in
// `sqrt` vs `std::sqrt`; they look like the removed and added sides of a
// diff hunk, so presumably only the `std::sqrt` one is in the real source
// after this commit -- confirm against the repository.
wigner_d_risbo_openmp(size_t lmax, double ang)
: p(sin(ang/2)), q(cos(ang/2)), sqt(2*lmax+1),
d({lmax+1,2*lmax+1}), dd({lmax+1,2*lmax+1}), n(-1)
{ for (size_t m=0; m<sqt.size(); ++m) sqt[m] = sqrt(double(m)); }
{ for (size_t m=0; m<sqt.size(); ++m) sqt[m] = std::sqrt(double(m)); }
const mav<double,2> &recurse()
{
......
......@@ -6,6 +6,8 @@
#ifndef MRUTIL_INTERPOL_NG_H
#define MRUTIL_INTERPOL_NG_H
#define SIMD_INTERPOL
#include <vector>
#include <complex>
#include <cmath>
......@@ -13,6 +15,7 @@
#include "mr_util/math/gl_integrator.h"
#include "mr_util/math/es_kernel.h"
#include "mr_util/infra/mav.h"
#include "mr_util/infra/simd.h"
#include "mr_util/sharp/sharp.h"
#include "mr_util/sharp/sharp_almhelpers.h"
#include "mr_util/sharp/sharp_geomhelpers.h"
......@@ -126,6 +129,9 @@ template<typename T> class Interpolator
size_t supp;
ES_Kernel kernel;
size_t ncomp;
#ifdef SIMD_INTERPOL
mav<native_simd<T>,4> scube;
#endif
mav<T,4> cube; // the data cube (theta, phi, 2*mbeam+1, TGC)
void correct(mav<T,2> &arr, int spin)
......@@ -145,23 +151,25 @@ template<typename T> class Interpolator
for (size_t j=0; j<nphi0; ++j)
tmp.v(ntheta0-1,j) = arr(ntheta0-1,j);
auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads);
for (auto &f:fct) f/=nphi0;
vector<T> k2(fct.size());
for (size_t i=0; i<fct.size(); ++i) k2[i] = T(fct[i]/nphi0);
fmav<T> ftmp(tmp);
fmav<T> ftmp0(tmp.template subarray<2>({0,0},{nphi0, nphi0}));
convolve_1d(ftmp0, ftmp, 0, fct, nthreads);
convolve_1d(ftmp0, ftmp, 0, k2, nthreads);
fmav<T> ftmp2(tmp.template subarray<2>({0,0},{ntheta, nphi0}));
fmav<T> farr(arr);
convolve_1d(ftmp2, farr, 1, fct, nthreads);
convolve_1d(ftmp2, farr, 1, k2, nthreads);
}
void decorrect(mav<T,2> &arr, int spin)
{
T sfct = (spin&1) ? -1 : 1;
mav<T,2> tmp({nphi,nphi0});
auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads);
for (auto &f:fct) f/=nphi0;
vector<T> k2(fct.size());
for (size_t i=0; i<fct.size(); ++i) k2[i] = T(fct[i]/nphi0);
fmav<T> farr(arr);
fmav<T> ftmp2(tmp.template subarray<2>({0,0},{ntheta, nphi0}));
convolve_1d(farr, ftmp2, 1, fct, nthreads);
convolve_1d(farr, ftmp2, 1, k2, nthreads);
// extend to second half
for (size_t i=1, i2=nphi-1; i+1<ntheta; ++i,--i2)
for (size_t j=0,j2=nphi0/2; j<nphi0; ++j,++j2)
......@@ -171,14 +179,14 @@ template<typename T> class Interpolator
}
fmav<T> ftmp(tmp);
fmav<T> ftmp0(tmp.template subarray<2>({0,0},{nphi0, nphi0}));
convolve_1d(ftmp, ftmp0, 0, fct, nthreads);
convolve_1d(ftmp, ftmp0, 0, k2, nthreads);
for (size_t j=0; j<nphi0; ++j)
arr.v(0,j) = 0.5*tmp(0,j);
arr.v(0,j) = T(0.5)*tmp(0,j);
for (size_t i=1; i+1<ntheta0; ++i)
for (size_t j=0; j<nphi0; ++j)
arr.v(i,j) = tmp(i,j);
for (size_t j=0; j<nphi0; ++j)
arr.v(ntheta0-1,j) = 0.5*tmp(ntheta0-1,j);
arr.v(ntheta0-1,j) = T(0.5)*tmp(ntheta0-1,j);
}
vector<size_t> getIdx(const mav<T,2> &ptg) const
......@@ -218,7 +226,12 @@ template<typename T> class Interpolator
supp(ES_Kernel::get_supp(epsilon, ofactor)),
kernel(supp, ofactor, nthreads),
ncomp(separate ? slm.size() : 1),
cube({ntheta+2*supp, nphi+2*supp, 2*kmax+1, ncomp})
#ifdef SIMD_INTERPOL
scube({ntheta+2*supp, nphi+2*supp, ncomp, (2*kmax+1+native_simd<T>::size()-1)/native_simd<T>::size()}),
cube(reinterpret_cast<T *>(scube.vdata()),{ntheta+2*supp, nphi+2*supp, ncomp, ((2*kmax+1+native_simd<T>::size()-1)/native_simd<T>::size())*native_simd<T>::size()},true)
#else
cube({ntheta+2*supp, nphi+2*supp, ncomp, 2*kmax+1})
#endif
{
MR_assert((ncomp==1)||(ncomp==3), "currently only 1 or 3 components allowed");
MR_assert(slm.size()==blm.size(), "inconsistent slm and blm vectors");
......@@ -237,7 +250,7 @@ template<typename T> class Interpolator
vector<T>lnorm(lmax+1);
for (size_t i=0; i<=lmax; ++i)
lnorm[i]=std::sqrt(4*pi/(2*i+1.));
lnorm[i]=T(std::sqrt(4*pi/(2*i+1.)));
for (size_t icomp=0; icomp<ncomp; ++icomp)
{
......@@ -245,15 +258,15 @@ template<typename T> class Interpolator
for (size_t l=m; l<=lmax; ++l)
{
if (separate)
a1(l,m) = slm[icomp](l,m)*blm[icomp](l,0).real()*T(lnorm[l]);
a1(l,m) = slm[icomp](l,m)*blm[icomp](l,0).real()*lnorm[l];
else
{
a1(l,m) = 0;
for (size_t j=0; j<slm.size(); ++j)
a1(l,m) += slm[j](l,m)*blm[j](l,0).real()*T(lnorm[l]);
a1(l,m) += slm[j](l,m)*blm[j](l,0).real()*lnorm[l];
}
}
auto m1 = cube.template subarray<2>({supp,supp,0,icomp},{ntheta,nphi,0,0});
auto m1 = cube.template subarray<2>({supp,supp,icomp,0},{ntheta,nphi,0,0});
sharp_alm2map(a1.Alms().data(), m1.vdata(), *ginfo, *ainfo, 0, nthreads);
correct(m1,0);
......@@ -268,7 +281,7 @@ template<typename T> class Interpolator
{
if (separate)
{
auto tmp = blm[icomp](l,k)*T(-2*lnorm[l]);
auto tmp = blm[icomp](l,k)*(-2*lnorm[l]);
a1(l,m) = slm[icomp](l,m)*tmp.real();
a2(l,m) = slm[icomp](l,m)*tmp.imag();
}
......@@ -277,15 +290,15 @@ template<typename T> class Interpolator
a1(l,m) = a2(l,m) = 0;
for (size_t j=0; j<slm.size(); ++j)
{
auto tmp = blm[j](l,k)*T(-2*lnorm[l]);
auto tmp = blm[j](l,k)*(-2*lnorm[l]);
a1(l,m) += slm[j](l,m)*tmp.real();
a2(l,m) += slm[j](l,m)*tmp.imag();
}
}
}
}
auto m1 = cube.template subarray<2>({supp,supp,2*k-1,icomp},{ntheta,nphi,0,0});
auto m2 = cube.template subarray<2>({supp,supp,2*k ,icomp},{ntheta,nphi,0,0});
auto m1 = cube.template subarray<2>({supp,supp,icomp,2*k-1},{ntheta,nphi,0,0});
auto m2 = cube.template subarray<2>({supp,supp,icomp,2*k },{ntheta,nphi,0,0});
sharp_alm2map_spin(k, a1.Alms().data(), a2.Alms().data(), m1.vdata(),
m2.vdata(), *ginfo, *ainfo, 0, nthreads);
correct(m1,k);
......@@ -296,23 +309,23 @@ template<typename T> class Interpolator
// fill border regions
for (size_t i=0; i<supp; ++i)
for (size_t j=0, j2=nphi/2; j<nphi; ++j,++j2)
for (size_t k=0; k<cube.shape(2); ++k)
for (size_t k=0; k<cube.shape(3); ++k)
{
T fct = (((k+1)/2)&1) ? -1 : 1;
if (j2>=nphi) j2-=nphi;
for (size_t l=0; l<cube.shape(3); ++l)
for (size_t l=0; l<cube.shape(2); ++l)
{
cube.v(supp-1-i,j2+supp,k,l) = fct*cube(supp+1+i,j+supp,k,l);
cube.v(supp+ntheta+i,j2+supp,k,l) = fct*cube(supp+ntheta-2-i,j+supp,k,l);
cube.v(supp-1-i,j2+supp,l,k) = fct*cube(supp+1+i,j+supp,l,k);
cube.v(supp+ntheta+i,j2+supp,l,k) = fct*cube(supp+ntheta-2-i,j+supp,l,k);
}
}
for (size_t i=0; i<ntheta+2*supp; ++i)
for (size_t j=0; j<supp; ++j)
for (size_t k=0; k<cube.shape(2); ++k)
for (size_t l=0; l<cube.shape(3); ++l)
for (size_t k=0; k<cube.shape(3); ++k)
for (size_t l=0; l<cube.shape(2); ++l)
{
cube.v(i,j,k,l) = cube(i,j+nphi,k,l);
cube.v(i,j+nphi+supp,k,l) = cube(i,j+supp,k,l);
cube.v(i,j,l,k) = cube(i,j+nphi,l,k);
cube.v(i,j+nphi+supp,l,k) = cube(i,j+supp,l,k);
}
}
......@@ -329,7 +342,12 @@ template<typename T> class Interpolator
supp(ES_Kernel::get_supp(epsilon, ofactor)),
kernel(supp, ofactor, nthreads),
ncomp(ncomp_),
cube({ntheta+2*supp, nphi+2*supp, 2*kmax+1, ncomp_})
#ifdef SIMD_INTERPOL
scube({ntheta+2*supp, nphi+2*supp, ncomp, (2*kmax+1+native_simd<T>::size()-1)/native_simd<T>::size()}),
cube(reinterpret_cast<T *>(scube.vdata()),{ntheta+2*supp, nphi+2*supp, ncomp, ((2*kmax+1+native_simd<T>::size()-1)/native_simd<T>::size())*native_simd<T>::size()},true)
#else
cube({ntheta+2*supp, nphi+2*supp, ncomp, 2*kmax+1})
#endif
{
MR_assert((ncomp==1)||(ncomp==3), "currently only 1 or 3 components allowed");
MR_assert((supp<=ntheta) && (supp<=nphi), "support too large!");
......@@ -338,6 +356,10 @@ template<typename T> class Interpolator
void interpol (const mav<T,2> &ptg, mav<T,2> &res) const
{
#ifdef SIMD_INTERPOL
constexpr size_t vl=native_simd<T>::size();
MR_assert(scube.stride(3)==1, "bad stride");
#endif
MR_assert(!adjoint, "cannot be called in adjoint mode");
MR_assert(ptg.shape(0)==res.shape(0), "dimension mismatch");
MR_assert(ptg.shape(1)==3, "second dimension must have length 3");
......@@ -350,6 +372,10 @@ template<typename T> class Interpolator
{
vector<T> wt(supp), wp(supp);
vector<T> psiarr(2*kmax+1);
#ifdef SIMD_INTERPOL
vector<native_simd<T>> psiarr2((2*kmax+1+vl-1)/vl);
for (auto &v:psiarr2) v=0;
#endif
while (auto rng=sched.getNext()) for(auto ind=rng.lo; ind<rng.hi; ++ind)
{
size_t i=idx[ind];
......@@ -373,30 +399,85 @@ template<typename T> class Interpolator
cnpsi=cnpsi*cpsi - snpsi*spsi;
snpsi=tmp;
}
#ifdef SIMD_INTERPOL
memcpy(reinterpret_cast<T *>(psiarr2.data()), psiarr.data(),
(2*kmax+1)*sizeof(T));
#endif
if (ncomp==1)
{
#ifndef SIMD_INTERPOL
T vv=0;
for (size_t j=0; j<supp; ++j)
for (size_t k=0; k<supp; ++k)
for (size_t l=0; l<2*kmax+1; ++l)
vv += cube(i0+j,i1+k,l,0)*wt[j]*wp[k]*psiarr[l];
vv += cube(i0+j,i1+k,0,l)*wt[j]*wp[k]*psiarr[l];
res.v(i,0) = vv;
#else
const native_simd<T> *p=&scube(i0,i1,0,0);
ptrdiff_t d0 = scube.stride(0);
ptrdiff_t d1 = scube.stride(1);
size_t nv=scube.shape(3);
native_simd<T> vv=0.;
for (size_t j=0; j<supp; ++j, p+=d0)
{
auto p1=p;
T wtj = wt[j];
for (size_t k=0; k<supp; ++k, p1+=d1)
{
auto p2=p1;
native_simd<T> tvv=0;
for (size_t l=0; l<nv; ++l,++p2)
tvv += *p2*psiarr2[l];
vv += wtj*wp[k]*tvv;
}
}
res.v(i,0) = reduce(vv, std::plus<>());
#endif
}
else // ncomp==3
{
#ifndef SIMD_INTERPOL
T v0=0., v1=0., v2=0.;
for (size_t j=0; j<supp; ++j)
for (size_t k=0; k<supp; ++k)
for (size_t l=0; l<2*kmax+1; ++l)
{
auto tmp = wt[j]*wp[k]*psiarr[l];
v0 += cube(i0+j,i1+k,l,0)*tmp;
v1 += cube(i0+j,i1+k,l,1)*tmp;
v2 += cube(i0+j,i1+k,l,2)*tmp;
v0 += cube(i0+j,i1+k,0,l)*tmp;
v1 += cube(i0+j,i1+k,1,l)*tmp;
v2 += cube(i0+j,i1+k,2,l)*tmp;
}
res.v(i,0) = v0;
res.v(i,1) = v1;
res.v(i,2) = v2;
#else
const native_simd<T> *p=&scube(i0,i1,0,0);
ptrdiff_t d0 = scube.stride(0);
ptrdiff_t d1 = scube.stride(1);
ptrdiff_t d2 = scube.stride(2);
size_t nv=scube.shape(3);
native_simd<T> v0=0., v1=0., v2=0.;
for (size_t j=0; j<supp; ++j, p+=d0)
{
auto p1=p;
T wtj = wt[j];
for (size_t k=0; k<supp; ++k, p1+=d1)
{
auto p2=p1;
native_simd<T> wtjwpk = wtj*wp[k];
for (size_t l=0; l<nv; ++l,++p2)
{
auto tmp = wtjwpk*psiarr2[l];
v0 += p2[0]*tmp;
v1 += p2[d2]*tmp;
v2 += p2[2*d2]*tmp;
}
}
}
res.v(i,0) = reduce(v0, std::plus<>());
res.v(i,1) = reduce(v1, std::plus<>());
res.v(i,2) = reduce(v2, std::plus<>());
#endif
}
}
});
......@@ -407,6 +488,10 @@ template<typename T> class Interpolator
void deinterpol (const mav<T,2> &ptg, const mav<T,2> &data)
{
#ifdef SIMD_INTERPOL
constexpr size_t vl=native_simd<T>::size();
MR_assert(scube.stride(3)==1, "bad stride");
#endif
MR_assert(adjoint, "can only be called in adjoint mode");
MR_assert(ptg.shape(0)==data.shape(0), "dimension mismatch");
MR_assert(ptg.shape(1)==3, "second dimension must have length 3");
......@@ -426,14 +511,18 @@ template<typename T> class Interpolator
size_t b_theta=99999999999999, b_phi=9999999999999999;
vector<T> wt(supp), wp(supp);
vector<T> psiarr(2*kmax+1);
#ifdef SIMD_INTERPOL
vector<native_simd<T>> psiarr2((2*kmax+1+vl-1)/vl);
for (auto &v:psiarr2) v=0;
#endif
while (auto rng=sched.getNext()) for(auto ind=rng.lo; ind<rng.hi; ++ind)
{
size_t i=idx[ind];
T f0=0.5*supp+ptg(i,0)*xdtheta;
T f0=T(0.5)*supp+ptg(i,0)*xdtheta;
size_t i0 = size_t(f0+1.);
for (size_t t=0; t<supp; ++t)
wt[t] = kernel((t+i0-f0)*delta - 1);
T f1=0.5*supp+ptg(i,1)*xdphi;
T f1=T(0.5)*supp+ptg(i,1)*xdphi;
size_t i1 = size_t(f1+1.);
for (size_t t=0; t<supp; ++t)
wp[t] = kernel((t+i1-f1)*delta - 1);
......@@ -467,16 +556,41 @@ template<typename T> class Interpolator
locks.v(b_theta+1,b_phi).lock();
locks.v(b_theta+1,b_phi+1).lock();
}
#ifdef SIMD_INTERPOL
memcpy(reinterpret_cast<T *>(psiarr2.data()), psiarr.data(),
(2*kmax+1)*sizeof(T));
#endif
if (ncomp==1)
{
#ifndef SIMD_INTERPOL
T val = data(i,0);
for (size_t j=0; j<supp; ++j)
for (size_t k=0; k<supp; ++k)
for (size_t l=0; l<2*kmax+1; ++l)
cube.v(i0+j,i1+k,l,0) += val*wt[j]*wp[k]*psiarr[l];
cube.v(i0+j,i1+k,0,l) += val*wt[j]*wp[k]*psiarr[l];
#else
native_simd<T> val = data(i,0);
native_simd<T> *p=&scube.v(i0,i1,0,0);
ptrdiff_t d0 = scube.stride(0);
ptrdiff_t d1 = scube.stride(1);
size_t nv=scube.shape(3);
for (size_t j=0; j<supp; ++j, p+=d0)
{
auto p1=p;
T wtj = wt[j];
for (size_t k=0; k<supp; ++k, p1+=d1)
{
auto p2=p1;
native_simd<T> tv = wtj*wp[k]*val;
for (size_t l=0; l<nv; ++l,++p2)
*p2 += tv*psiarr2[l];
}
}
#endif
}
else // ncomp==3
{
#ifndef SIMD_INTERPOL
T v0=data(i,0), v1=data(i,1), v2=data(i,2);
for (size_t j=0; j<supp; ++j)
for (size_t k=0; k<supp; ++k)
......@@ -485,11 +599,36 @@ template<typename T> class Interpolator
for (size_t l=0; l<2*kmax+1; ++l)
{
T tmp = t0*psiarr[l];
cube.v(i0+j,i1+k,l,0) += v0*tmp;
cube.v(i0+j,i1+k,l,1) += v1*tmp;
cube.v(i0+j,i1+k,l,2) += v2*tmp;
cube.v(i0+j,i1+k,0,l) += v0*tmp;
cube.v(i0+j,i1+k,1,l) += v1*tmp;
cube.v(i0+j,i1+k,2,l) += v2*tmp;
}
}
#else
native_simd<T> v0=data(i,0), v1=data(i,1), v2=data(i,2);
native_simd<T> *p=&scube.v(i0,i1,0,0);
ptrdiff_t d0 = scube.stride(0);
ptrdiff_t d1 = scube.stride(1);
ptrdiff_t d2 = scube.stride(2);
size_t nv=scube.shape(3);
for (size_t j=0; j<supp; ++j, p+=d0)
{
auto p1=p;
T wtj = wt[j];
for (size_t k=0; k<supp; ++k, p1+=d1)
{
auto p2=p1;
native_simd<T> wtjwpk = wtj*wp[k];
for (size_t l=0; l<nv; ++l,++p2)
{
auto tmp = wtjwpk*psiarr2[l];
p2[0] += v0*tmp;
p2[d2] += v1*tmp;
p2[2*d2] += v2*tmp;
}
}
}
#endif
}
}
if (b_theta<locks.shape(0)) // unlock
......@@ -513,43 +652,43 @@ template<typename T> class Interpolator
// move stuff from border regions onto the main grid
for (size_t i=0; i<cube.shape(0); ++i)
for (size_t j=0; j<supp; ++j)
for (size_t k=0; k<cube.shape(2); ++k)
for (size_t l=0; l<cube.shape(3); ++l)
for (size_t k=0; k<cube.shape(3); ++k)
for (size_t l=0; l<cube.shape(2); ++l)
{
cube.v(i,j+nphi,k,l) += cube(i,j,k,l);
cube.v(i,j+supp,k,l) += cube(i,j+nphi+supp,k,l);
cube.v(i,j+nphi,l,k) += cube(i,j,l,k);
cube.v(i,j+supp,l,k) += cube(i,j+nphi+supp,l,k);
}
for (size_t i=0; i<supp; ++i)
for (size_t j=0, j2=nphi/2; j<nphi; ++j,++j2)
for (size_t k=0; k<cube.shape(2); ++k)
for (size_t k=0; k<cube.shape(3); ++k)
{
T fct = (((k+1)/2)&1) ? -1 : 1;
if (j2>=nphi) j2-=nphi;
for (size_t l=0; l<cube.shape(3); ++l)
for (size_t l=0; l<cube.shape(2); ++l)
{
cube.v(supp+1+i,j+supp,k,l) += fct*cube(supp-1-i,j2+supp,k,l);
cube.v(supp+ntheta-2-i, j+supp,k,l) += fct*cube(supp+ntheta+i,j2+supp,k,l);
cube.v(supp+1+i,j+supp,l,k) += fct*cube(supp-1-i,j2+supp,l,k);
cube.v(supp+ntheta-2-i, j+supp,l,k) += fct*cube(supp+ntheta+i,j2+supp,l,k);
}
}
// special treatment for poles
for (size_t j=0,j2=nphi/2; j<nphi/2; ++j,++j2)
for (size_t k=0; k<cube.shape(2); ++k)
for (size_t l=0; l<cube.shape(3); ++l)
for (size_t k=0; k<cube.shape(3); ++k)
for (size_t l=0; l<cube.shape(2); ++l)
{
T fct = (((k+1)/2)&1) ? -1 : 1;
if (j2>=nphi) j2-=nphi;
T tval = (cube(supp,j+supp,k,l) + fct*cube(supp,j2+supp,k,l));
cube.v(supp,j+supp,k,l) = tval;
cube.v(supp,j2+supp,k,l) = fct*tval;
tval = (cube(supp+ntheta-1,j+supp,k,l) + fct*cube(supp+ntheta-1,j2+supp,k,l));
cube.v(supp+ntheta-1,j+supp,k,l) = tval;
cube.v(supp+ntheta-1,j2+supp,k,l) = fct*tval;
T tval = (cube(supp,j+supp,l,k) + fct*cube(supp,j2+supp,l,k));
cube.v(supp,j+supp,l,k) = tval;
cube.v(supp,j2+supp,l,k) = fct*tval;
tval = (cube(supp+ntheta-1,j+supp,l,k) + fct*cube(supp+ntheta-1,j2+supp,l,k));
cube.v(supp+ntheta-1,j+supp,l,k) = tval;
cube.v(supp+ntheta-1,j2+supp,l,k) = fct*tval;
}
vector<T>lnorm(lmax+1);
for (size_t i=0; i<=lmax; ++i)
lnorm[i]=std::sqrt(4*pi/(2*i+1.));
lnorm[i]=T(std::sqrt(4*pi/(2*i+1.)));
for (size_t j=0; j<blm.size(); ++j)
slm[j].SetToZero();
......@@ -558,21 +697,21 @@ template<typename T> class Interpolator
{
bool separate = ncomp>1;
{
auto m1 = cube.template subarray<2>({supp,supp,0,icomp},{ntheta,nphi,0,0});
auto m1 = cube.template subarray<2>({supp,supp,icomp,0},{ntheta,nphi,0,0});
decorrect(m1,0);
sharp_alm2map_adjoint(a1.Alms().vdata(), m1.data(), *ginfo, *ainfo, 0, nthreads);
for (size_t m=0; m<=lmax; ++m)
for (size_t l=m; l<=lmax; ++l)
if (separate)
slm[icomp](l,m) += conj(a1(l,m))*blm[icomp](l,0).real()*T(lnorm[l]);
slm[icomp](l,m) += conj(a1(l,m))*blm[icomp](l,0).real()*lnorm[l];
else
for (size_t j=0; j<blm.size(); ++j)
slm[j](l,m) += conj(a1(l,m))*blm[j](l,0).real()*T(lnorm[l]);
slm[j](l,m) += conj(a1(l,m))*blm[j](l,0).real()*lnorm[l];
}
for (size_t k=1; k<=kmax; ++k)
{
auto m1 = cube.template subarray<2>({supp,supp,2*k-1,icomp},{ntheta,nphi,0,0});
auto m2 = cube.template subarray<2>({supp,supp,2*k ,icomp},{ntheta,nphi,0,0});
auto m1 = cube.template subarray<2>({supp,supp,icomp,2*k-1},{ntheta,nphi,0,0});
auto m2 = cube.template subarray<2>({supp,supp,icomp,2*k },{ntheta,nphi,0,0});
decorrect(m1,k);
decorrect(m2,k);
......@@ -584,14 +723,14 @@ template<typename T> class Interpolator
{
if (separate)
{
auto tmp = conj(blm[icomp](l,k))*T(-2*lnorm[l]);
auto tmp = conj(blm[icomp](l,k))*(-2*lnorm[l]);
slm[icomp](l,m) += conj(a1(l,m))*tmp.real();
slm[icomp](l,m) -= conj(a2(l,m))*tmp.imag();
}
else
for (size_t j=0; j<blm.size(); ++j)
{
auto tmp = conj(blm[j](l,k))*T(-2*lnorm[l]);
auto tmp = conj(blm[j](l,k))*(-2*lnorm[l]);
slm[j](l,m) += conj(a1(l,m))*tmp.real();
slm[j](l,m) -= conj(a2(l,m))*tmp.imag();
}
......
......@@ -14,8 +14,6 @@ namespace py = pybind11;
namespace {
using fptype = double;
template<typename T> class PyInterpolator: public Interpolator<T>
{
protected:
......@@ -244,28 +242,28 @@ PYBIND11_MODULE(pyinterpol_ng, m)
using inter_d = PyInterpolator<double>;
py::class_<inter_d> (m, "PyInterpolator", pyinterpolator_DS)
.def(py::init<const py::array &, const py::array &, bool, int64_t, int64_t, fptype, fptype, int>(),
initnormal_DS, "sky"_a, "beam"_a, "separate"_a, "lmax"_a, "kmax"_a, "epsilon"_a, "ofactor"_a=fptype(1.5),
.def(py::init<const py::array &, const py::array &, bool, int64_t, int64_t, double, double, int>(),
initnormal_DS, "sky"_a, "beam"_a, "separate"_a, "lmax"_a, "kmax"_a, "epsilon"_a, "ofactor"_a=1.5,
"nthreads"_a=0)
.def(py::init<int64_t, int64_t, int64_t, fptype, fptype, int>(), initadjoint_DS,
"lmax"_a, "kmax"_a, "ncomp"_a, "epsilon"_a, "ofactor"_a=fptype(1.5), "nthreads"_a=0)
.def(py::init<int64_t, int64_t, int64_t, double, double, int>(), initadjoint_DS,
"lmax"_a, "kmax"_a, "ncomp"_a, "epsilon"_a, "ofactor"_a=1.5, "nthreads"_a=0)
.def ("interpol", &inter_d::pyinterpol, interpol_DS, "ptg"_a)
.def ("deinterpol", &inter_d::pydeinterpol, deinterpol_DS, "ptg"_a, "data"_a)
.def ("getSlm", &inter_d::pygetSlm, getSlm_DS, "beam"_a)
.def ("support", &inter_d::support);
// using inter_f = PyInterpolator<float>;
// py::class_<inter_f> (m, "PyInterpolator_f", pyinterpolator_DS)