Commit 9dfddec1 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'adjoint' into 'master'

Adjoint 4pi convolution

See merge request mtr/cxxbase!3
parents ae4c6c2f a213c18f
...@@ -212,52 +212,63 @@ class wigner_d_risbo_openmp ...@@ -212,52 +212,63 @@ class wigner_d_risbo_openmp
template<typename T> void rotate_alm (Alm<complex<T>> &alm, template<typename T> void rotate_alm (Alm<complex<T>> &alm,
double psi, double theta, double phi) double psi, double theta, double phi)
{ {
MR_assert (alm.Lmax()==alm.Mmax(),
"rotate_alm: lmax must be equal to mmax");
auto lmax=alm.Lmax(); auto lmax=alm.Lmax();
vector<complex<double> > exppsi(lmax+1), expphi(lmax+1); MR_assert (lmax==alm.Mmax(),
for (size_t m=0; m<=lmax; ++m) "rotate_alm: lmax must be equal to mmax");
{
exppsi[m] = polar(1.,-psi*m);
expphi[m] = polar(1.,-phi*m);
}
wigner_d_risbo_openmp rec(lmax,theta);
vector<complex<double> > almtmp(lmax+1); if (theta!=0)
size_t nthreads=1;
for (size_t l=0; l<=lmax; ++l)
{ {
const auto &d(rec.recurse()); vector<complex<double> > exppsi(lmax+1), expphi(lmax+1);
for (size_t m=0; m<=lmax; ++m)
for (size_t m=0; m<=l; ++m) {
almtmp[m] = complex<double>(alm(l,0))*d(l,l+m); exppsi[m] = polar(1.,-psi*m);
expphi[m] = polar(1.,-phi*m);
execStatic(l+1, nthreads, 0, [&](Scheduler &sched) }
vector<complex<double> > almtmp(lmax+1);
size_t nthreads=1;
wigner_d_risbo_openmp rec(lmax,theta);
for (size_t l=0; l<=lmax; ++l)
{ {
auto rng=sched.getNext(); const auto &d(rec.recurse());
auto lo=rng.lo;
auto hi=rng.hi;
bool flip = true; for (size_t m=0; m<=l; ++m)
for (size_t mm=1; mm<=l; ++mm) almtmp[m] = complex<double>(alm(l,0))*d(l,l+m);
execStatic(l+1, nthreads, 0, [&](Scheduler &sched)
{ {
auto t1 = complex<double>(alm(l,mm))*exppsi[mm]; auto rng=sched.getNext();
bool flip2 = ((mm+lo)&1); auto lo=rng.lo;
for (auto m=lo; m<hi; ++m) auto hi=rng.hi;
bool flip = true;
for (size_t mm=1; mm<=l; ++mm)
{ {
double d1 = flip2 ? -d(l-mm,l-m) : d(l-mm,l-m); auto t1 = complex<double>(alm(l,mm))*exppsi[mm];
double d2 = flip ? -d(l-mm,l+m) : d(l-mm,l+m); bool flip2 = ((mm+lo)&1);
double f1 = d1+d2, f2 = d1-d2; for (auto m=lo; m<hi; ++m)
almtmp[m]+=complex<double>(t1.real()*f1,t1.imag()*f2); {
flip2 = !flip2; double d1 = flip2 ? -d(l-mm,l-m) : d(l-mm,l-m);
double d2 = flip ? -d(l-mm,l+m) : d(l-mm,l+m);
double f1 = d1+d2, f2 = d1-d2;
almtmp[m]+=complex<double>(t1.real()*f1,t1.imag()*f2);
flip2 = !flip2;
}
flip = !flip;
} }
flip = !flip; });
}
});
for (size_t m=0; m<=l; ++m) for (size_t m=0; m<=l; ++m)
alm(l,m) = complex<T>(almtmp[m]*expphi[m]); alm(l,m) = complex<T>(almtmp[m]*expphi[m]);
}
}
else
{
for (size_t m=0; m<=lmax; ++m)
{
auto ang = polar(1.,-(psi+phi)*m);
for (size_t l=m; l<=lmax; ++l)
alm(l,m) *= ang;
}
} }
} }
#endif #endif
......
...@@ -18,11 +18,17 @@ def random_alm(lmax, mmax): ...@@ -18,11 +18,17 @@ def random_alm(lmax, mmax):
return res return res
def deltabeam(lmax,kmax): def compress_alm(alm,lmax):
beam=np.zeros(nalm(lmax, kmax))+0j res = np.empty(2*len(alm)-lmax-1, dtype=np.float64)
for l in range(lmax+1): res[0:lmax+1] = alm[0:lmax+1].real
beam[l] = np.sqrt((2*l+1.)/(4*np.pi)) res[lmax+1::2] = np.sqrt(2)*alm[lmax+1:].real
return beam res[lmax+2::2] = np.sqrt(2)*alm[lmax+1:].imag
return res
def myalmdot(a1,a2,lmax,mmax,spin):
return np.vdot(compress_alm(a1,lmax),compress_alm(np.conj(a2),lmax))
def convolve(alm1, alm2, lmax): def convolve(alm1, alm2, lmax):
job = pysharp.sharpjob_d() job = pysharp.sharpjob_d()
...@@ -32,28 +38,33 @@ def convolve(alm1, alm2, lmax): ...@@ -32,28 +38,33 @@ def convolve(alm1, alm2, lmax):
job.set_triangular_alm_info(0,0) job.set_triangular_alm_info(0,0)
return job.map2alm(map)[0]*np.sqrt(4*np.pi) return job.map2alm(map)[0]*np.sqrt(4*np.pi)
lmax=10
mmax=lmax lmax=60
kmax=lmax kmax=13
# get random sky a_lm # get random sky a_lm
# the a_lm arrays follow the same conventions as those in healpy # the a_lm arrays follow the same conventions as those in healpy
slmT = random_alm(lmax, mmax) slmT = random_alm(lmax, lmax)
# build beam a_lm # build beam a_lm
blmT = random_alm(lmax, mmax) blmT = random_alm(lmax, kmax)
t0=time.time() t0=time.time()
# build interpolator object for slmT and blmT # build interpolator object for slmT and blmT
foo = interpol_ng.PyInterpolator(slmT,blmT,lmax, kmax, epsilon=1e-6, nthreads=2) foo = interpol_ng.PyInterpolator(slmT,blmT,lmax, kmax, epsilon=1e-6, nthreads=2)
print("setup time: ",time.time()-t0) print("setup time: ",time.time()-t0)
nth = 2*lmax+1 nth = lmax+1
nph = 2*mmax+1 nph = 2*lmax+1
# compute a convolved map at a fixed psi and compare it to a map convolved
# "by hand"
ptg = np.zeros((nth,nph,3)) ptg = np.zeros((nth,nph,3))
ptg[:,:,0] = (np.pi*np.arange(nth)/(nth-1)).reshape((-1,1)) ptg[:,:,0] = (np.pi*(0.5+np.arange(nth))/nth).reshape((-1,1))
ptg[:,:,1] = (2*np.pi*np.arange(nph)/nph).reshape((1,-1)) ptg[:,:,1] = (2*np.pi*(0.5+np.arange(nph))/nph).reshape((1,-1))
ptg[:,:,2] = np.pi*0.2 ptg[:,:,2] = np.pi*0.2
t0=time.time() t0=time.time()
# do the actual interpolation # do the actual interpolation
...@@ -62,12 +73,30 @@ print("interpolation time: ", time.time()-t0) ...@@ -62,12 +73,30 @@ print("interpolation time: ", time.time()-t0)
plt.subplot(2,2,1) plt.subplot(2,2,1)
plt.imshow(bar.reshape((nth,nph))) plt.imshow(bar.reshape((nth,nph)))
bar2 = np.zeros((nth,nph)) bar2 = np.zeros((nth,nph))
blmTfull = np.zeros(slmT.size)+0j
blmTfull[0:blmT.size] = blmT
for ith in range(nth): for ith in range(nth):
rbeamth=interpol_ng.rotate_alm(blmTfull, lmax, ptg[ith,0,2],ptg[ith,0,0],0)
for iph in range(nph): for iph in range(nph):
rbeam=interpol_ng.rotate_alm(blmT, lmax, ptg[ith,iph,2],ptg[ith,iph,0],ptg[ith,iph,1]) rbeam=interpol_ng.rotate_alm(rbeamth, lmax, 0, 0, ptg[ith,iph,1])
bar2[ith,iph] = convolve(slmT, rbeam, lmax).real bar2[ith,iph] = convolve(slmT, rbeam, lmax).real
plt.subplot(2,2,2) plt.subplot(2,2,2)
plt.imshow(bar2) plt.imshow(bar2)
plt.subplot(2,2,3) plt.subplot(2,2,3)
plt.imshow((bar2-bar.reshape((nth,nph)))) plt.imshow((bar2-bar.reshape((nth,nph))))
plt.show() plt.show()
ptg=np.random.uniform(0.,1.,3*1000000).reshape(1000000,3)
ptg[:,0]*=np.pi
ptg[:,1]*=2*np.pi
ptg[:,2]*=2*np.pi
foo = interpol_ng.PyInterpolator(slmT,blmT,lmax, kmax, epsilon=1e-6, nthreads=1)
bar=foo.interpol(ptg)
fake = np.random.uniform(0.,1., ptg.shape[0])
foo2 = interpol_ng.PyInterpolator(lmax, kmax, epsilon=1e-6, nthreads=2)
foo2.deinterpol(ptg.reshape((-1,3)), fake)
bla=foo2.getSlm(blmT)
print(myalmdot(slmT, bla, lmax, lmax, 0))
print(np.vdot(fake,bar))
print(myalmdot(slmT, bla, lmax, lmax, 0)/np.vdot(fake,bar))
...@@ -49,14 +49,20 @@ template<typename T> class Interpolator ...@@ -49,14 +49,20 @@ template<typename T> class Interpolator
for (size_t j=0; j<nphi0; ++j) for (size_t j=0; j<nphi0; ++j)
tmp0.v(i,j) = arr(i,j); tmp0.v(i,j) = arr(i,j);
// extend to second half // extend to second half
for (size_t i=1, i2=2*ntheta0-3; i+1<ntheta0; ++i,--i2) for (size_t i=1, i2=nphi0-1; i+1<ntheta0; ++i,--i2)
for (size_t j=0,j2=nphi0/2; j<nphi0; ++j,++j2) for (size_t j=0,j2=nphi0/2; j<nphi0; ++j,++j2)
{ {
if (j2>=nphi0) j2-=nphi0; if (j2>=nphi0) j2-=nphi0;
tmp0.v(i2,j) = sfct*tmp0(i,j2); tmp0.v(i2,j) = sfct*tmp0(i,j2);
} }
// FFT to frequency domain on minimal grid // FFT to frequency domain on minimal grid
r2r_fftpack(ftmp0,ftmp0,{1,0},true,true,1./(nphi0*nphi0),nthreads); r2r_fftpack(ftmp0,ftmp0,{0,1},true,true,1./(nphi0*nphi0),nthreads);
// correct amplitude at Nyquist frequency
for (size_t i=0; i<nphi0; ++i)
{
tmp0.v(i,nphi0-1)*=0.5;
tmp0.v(nphi0-1,i)*=0.5;
}
auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads); auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads);
for (size_t i=0; i<nphi0; ++i) for (size_t i=0; i<nphi0; ++i)
for (size_t j=0; j<nphi0; ++j) for (size_t j=0; j<nphi0; ++j)
...@@ -81,12 +87,13 @@ template<typename T> class Interpolator ...@@ -81,12 +87,13 @@ template<typename T> class Interpolator
for (size_t j=0; j<nphi; ++j) for (size_t j=0; j<nphi; ++j)
tmp.v(i,j) = arr(i,j); tmp.v(i,j) = arr(i,j);
// extend to second half // extend to second half
for (size_t i=1, i2=2*ntheta-3; i+1<ntheta; ++i,--i2) for (size_t i=1, i2=nphi-1; i+1<ntheta; ++i,--i2)
for (size_t j=0,j2=nphi/2; j<nphi; ++j,++j2) for (size_t j=0,j2=nphi/2; j<nphi; ++j,++j2)
{ {
if (j2>=nphi) j2-=nphi; if (j2>=nphi) j2-=nphi;
tmp.v(i2,j) = sfct*tmp(i,j2); tmp.v(i2,j) = sfct*tmp(i,j2);
} }
// FIXME: faster FFT
r2r_fftpack(ftmp,ftmp,{0,1},true,true,1.,nthreads); r2r_fftpack(ftmp,ftmp,{0,1},true,true,1.,nthreads);
auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads); auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads);
auto tmp0=tmp.template subarray<2>({0,0},{nphi0, nphi0}); auto tmp0=tmp.template subarray<2>({0,0},{nphi0, nphi0});
...@@ -96,11 +103,36 @@ template<typename T> class Interpolator ...@@ -96,11 +103,36 @@ template<typename T> class Interpolator
tmp0.v(i,j) *= fct[(i+1)/2] * fct[(j+1)/2]; tmp0.v(i,j) *= fct[(i+1)/2] * fct[(j+1)/2];
// FFT to (theta, phi) domain on minimal grid // FFT to (theta, phi) domain on minimal grid
r2r_fftpack(ftmp0,ftmp0,{0,1},false, false,1./(nphi0*nphi0),nthreads); r2r_fftpack(ftmp0,ftmp0,{0,1},false, false,1./(nphi0*nphi0),nthreads);
for (size_t j=0; j<nphi0; ++j)
{
tmp0.v(0,j)*=0.5;
tmp0.v(ntheta0-1,j)*=0.5;
}
for (size_t i=0; i<ntheta0; ++i) for (size_t i=0; i<ntheta0; ++i)
for (size_t j=0; j<nphi0; ++j) for (size_t j=0; j<nphi0; ++j)
arr.v(i,j) = tmp0(i,j); arr.v(i,j) = tmp0(i,j);
} }
vector<size_t> getIdx(const mav<T,2> &ptg) const
{
vector<size_t> idx(ptg.shape(0));
constexpr size_t cellsize=16;
size_t nct = ntheta/cellsize+1,
ncp = nphi/cellsize+1;
vector<vector<size_t>> mapper(nct*ncp);
for (size_t i=0; i<ptg.shape(0); ++i)
{
size_t itheta=min(nct-1,size_t(ptg(i,0)/pi*nct)),
iphi=min(ncp-1,size_t(ptg(i,1)/(2*pi)*ncp));
mapper[itheta*ncp+iphi].push_back(i);
}
size_t cnt=0;
for (const auto &vec: mapper)
for (auto i:vec)
idx[cnt++] = i;
return idx;
}
public: public:
Interpolator(const Alm<complex<T>> &slmT, const Alm<complex<T>> &blmT, Interpolator(const Alm<complex<T>> &slmT, const Alm<complex<T>> &blmT,
double epsilon, int nthreads_) double epsilon, int nthreads_)
...@@ -202,24 +234,7 @@ template<typename T> class Interpolator ...@@ -202,24 +234,7 @@ template<typename T> class Interpolator
double delta = 2./supp; double delta = 2./supp;
double xdtheta = (ntheta-1)/pi, double xdtheta = (ntheta-1)/pi,
xdphi = nphi/(2*pi); xdphi = nphi/(2*pi);
vector<size_t> idx(ptg.shape(0)); auto idx = getIdx(ptg);
{
// do some pre-sorting to improve cache use
constexpr size_t cellsize=16;
size_t nct = ntheta/cellsize+1,
ncp = nphi/cellsize+1;
vector<vector<size_t>> mapper(nct*ncp);
for (size_t i=0; i<ptg.shape(0); ++i)
{
size_t itheta=min(nct-1,size_t(ptg(i,0)/pi*nct)),
iphi=min(ncp-1,size_t(ptg(i,1)/(2*pi)*ncp));
mapper[itheta*ncp+iphi].push_back(i);
}
size_t cnt=0;
for (const auto &vec: mapper)
for (auto i:vec)
idx[cnt++] = i;
}
execStatic(idx.size(), nthreads, 0, [&](Scheduler &sched) execStatic(idx.size(), nthreads, 0, [&](Scheduler &sched)
{ {
vector<T> wt(supp), wp(supp); vector<T> wt(supp), wp(supp);
...@@ -265,24 +280,7 @@ template<typename T> class Interpolator ...@@ -265,24 +280,7 @@ template<typename T> class Interpolator
double delta = 2./supp; double delta = 2./supp;
double xdtheta = (ntheta-1)/pi, double xdtheta = (ntheta-1)/pi,
xdphi = nphi/(2*pi); xdphi = nphi/(2*pi);
vector<size_t> idx(ptg.shape(0)); auto idx = getIdx(ptg);
{
// do some pre-sorting to improve cache use
constexpr size_t cellsize=16;
size_t nct = ntheta/cellsize+1,
ncp = nphi/cellsize+1;
vector<vector<size_t>> mapper(nct*ncp);
for (size_t i=0; i<ptg.shape(0); ++i)
{
size_t itheta=min(nct-1,size_t(ptg(i,0)/pi*nct)),
iphi=min(ncp-1,size_t(ptg(i,1)/(2*pi)*ncp));
mapper[itheta*ncp+iphi].push_back(i);
}
size_t cnt=0;
for (const auto &vec: mapper)
for (auto i:vec)
idx[cnt++] = i;
}
execStatic(idx.size(), 1, 0, [&](Scheduler &sched) // not parallel yet execStatic(idx.size(), 1, 0, [&](Scheduler &sched) // not parallel yet
{ {
vector<T> wt(supp), wp(supp); vector<T> wt(supp), wp(supp);
...@@ -326,7 +324,7 @@ template<typename T> class Interpolator ...@@ -326,7 +324,7 @@ template<typename T> class Interpolator
auto ainfo = sharp_make_triangular_alm_info(lmax,lmax,1); auto ainfo = sharp_make_triangular_alm_info(lmax,lmax,1);
// move stuff from border regions onto the main grid // move stuff from border regions onto the main grid
for (size_t i=0; i<ntheta+2*supp; ++i) for (size_t i=0; i<cube.shape(0); ++i)
for (size_t j=0; j<supp; ++j) for (size_t j=0; j<supp; ++j)
for (size_t k=0; k<cube.shape(2); ++k) for (size_t k=0; k<cube.shape(2); ++k)
{ {
...@@ -342,7 +340,20 @@ template<typename T> class Interpolator ...@@ -342,7 +340,20 @@ template<typename T> class Interpolator
cube.v(supp+1+i,j+supp,k) += fct*cube(supp-1-i,j2+supp,k); cube.v(supp+1+i,j+supp,k) += fct*cube(supp-1-i,j2+supp,k);
cube.v(supp+ntheta-2-i, j+supp,k) += fct*cube(supp+ntheta+i,j2+supp,k); cube.v(supp+ntheta-2-i, j+supp,k) += fct*cube(supp+ntheta+i,j2+supp,k);
} }
for (size_t k=0; k<cube.shape(2); ++k)
{
double fct = (((k+1)/2)&1) ? -1 : 1;
for (size_t j=0,j2=nphi/2; j<nphi/2; ++j,++j2)
{
if (j2>=nphi) j2-=nphi;
double tval = (cube(supp,j+supp,k) + fct*cube(supp,j2+supp,k));
cube.v(supp,j+supp,k) = tval;
cube.v(supp,j2+supp,k) = fct*tval;
tval = (cube(supp+ntheta-1,j+supp,k) + fct*cube(supp+ntheta-1,j2+supp,k));
cube.v(supp+ntheta-1,j+supp,k) = tval;
cube.v(supp+ntheta-1,j2+supp,k) = fct*tval;
}
}
vector<double>lnorm(lmax+1); vector<double>lnorm(lmax+1);
for (size_t i=0; i<=lmax; ++i) for (size_t i=0; i<=lmax; ++i)
lnorm[i]=sqrt(4*pi/(2*i+1.)); lnorm[i]=sqrt(4*pi/(2*i+1.));
...@@ -372,7 +383,7 @@ template<typename T> class Interpolator ...@@ -372,7 +383,7 @@ template<typename T> class Interpolator
{ {
auto tmp = -2.*conj(blmT(l,k))*T(lnorm[l]); auto tmp = -2.*conj(blmT(l,k))*T(lnorm[l]);
slmT(l,m) += conj(a1(l,m))*tmp.real(); slmT(l,m) += conj(a1(l,m))*tmp.real();
slmT(l,m) += conj(a2(l,m))*tmp.imag(); slmT(l,m) -= conj(a2(l,m))*tmp.imag();
} }
} }
} }
...@@ -381,6 +392,13 @@ template<typename T> class Interpolator ...@@ -381,6 +392,13 @@ template<typename T> class Interpolator
template<typename T> class PyInterpolator: public Interpolator<T> template<typename T> class PyInterpolator: public Interpolator<T>
{ {
protected:
using Interpolator<T>::lmax;
using Interpolator<T>::kmax;
using Interpolator<T>::interpolx;
using Interpolator<T>::deinterpolx;
using Interpolator<T>::getSlmx;
public: public:
PyInterpolator(const py::array &slmT, const py::array &blmT, PyInterpolator(const py::array &slmT, const py::array &blmT,
int64_t lmax, int64_t kmax, double epsilon, int nthreads=0) int64_t lmax, int64_t kmax, double epsilon, int nthreads=0)
...@@ -389,11 +407,6 @@ template<typename T> class PyInterpolator: public Interpolator<T> ...@@ -389,11 +407,6 @@ template<typename T> class PyInterpolator: public Interpolator<T>
epsilon, nthreads) {} epsilon, nthreads) {}
PyInterpolator(int64_t lmax, int64_t kmax, double epsilon, int nthreads=0) PyInterpolator(int64_t lmax, int64_t kmax, double epsilon, int nthreads=0)
: Interpolator<T>(lmax, kmax, epsilon, nthreads) {} : Interpolator<T>(lmax, kmax, epsilon, nthreads) {}
using Interpolator<T>::interpolx;
using Interpolator<T>::deinterpolx;
using Interpolator<T>::getSlmx;
using Interpolator<T>::lmax;
using Interpolator<T>::kmax;
py::array interpol(const py::array &ptg) const py::array interpol(const py::array &ptg) const
{ {
auto ptg2 = to_mav<T,2>(ptg); auto ptg2 = to_mav<T,2>(ptg);
......
...@@ -39,7 +39,19 @@ def convolve(alm1, alm2, lmax): ...@@ -39,7 +39,19 @@ def convolve(alm1, alm2, lmax):
return job.map2alm(map)[0]*np.sqrt(4*np.pi) return job.map2alm(map)[0]*np.sqrt(4*np.pi)
@pmp("lkmax", [(43,43),(2,1),(30,15),(512,2)]) def compress_alm(alm,lmax):
res = np.empty(2*len(alm)-lmax-1, dtype=np.float64)
res[0:lmax+1] = alm[0:lmax+1].real
res[lmax+1::2] = np.sqrt(2)*alm[lmax+1:].real
res[lmax+2::2] = np.sqrt(2)*alm[lmax+1:].imag
return res
def myalmdot(a1,a2,lmax,mmax,spin):
return np.vdot(compress_alm(a1,lmax),compress_alm(np.conj(a2),lmax))
@pmp("lkmax", [(43,43),(2,1),(30,15),(125,2)])
def test_against_convolution(lkmax): def test_against_convolution(lkmax):
lmax, kmax = lkmax lmax, kmax = lkmax
slmT = random_alm(lmax, lmax) slmT = random_alm(lmax, lmax)
...@@ -62,3 +74,21 @@ def test_against_convolution(lkmax): ...@@ -62,3 +74,21 @@ def test_against_convolution(lkmax):
rbeam=interpol_ng.rotate_alm(blmT2, lmax, ptg[i,2],ptg[i,0],ptg[i,1]) rbeam=interpol_ng.rotate_alm(blmT2, lmax, ptg[i,2],ptg[i,0],ptg[i,1])
res2[i] = convolve(slmT, rbeam, lmax).real res2[i] = convolve(slmT, rbeam, lmax).real
_assert_close(res1, res2, 1e-7) _assert_close(res1, res2, 1e-7)
@pmp("lkmax", [(43,43),(2,1),(30,15),(125,2)])
def test_adjointness(lkmax):
lmax, kmax = lkmax
slmT = random_alm(lmax, lmax)
blmT = random_alm(lmax, kmax)
nptg=100000
ptg=np.random.uniform(0.,1.,nptg*3).reshape(nptg,3)
ptg[:,0]*=np.pi
ptg[:,1]*=2*np.pi
ptg[:,2]*=2*np.pi
foo = interpol_ng.PyInterpolator(slmT,blmT,lmax, kmax, epsilon=1e-6, nthreads=2)
inter1=foo.interpol(ptg)
fake = np.random.uniform(0.,1., ptg.shape[0])
foo2 = interpol_ng.PyInterpolator(lmax, kmax, epsilon=1e-6, nthreads=2)
foo2.deinterpol(ptg.reshape((-1,3)), fake)
bla=foo2.getSlm(blmT)
_assert_close(myalmdot(slmT, bla, lmax, lmax, 0), np.vdot(fake,inter1), 1e-12)
...@@ -213,6 +213,8 @@ template<size_t ndim> class mav_info ...@@ -213,6 +213,8 @@ template<size_t ndim> class mav_info
} }
bool conformable(const mav_info &other) const bool conformable(const mav_info &other) const
{ return shp==other.shp; } { return shp==other.shp; }
bool conformable(const shape_t &other) const
{ return shp==other; }
template<typename... Ns> ptrdiff_t idx(Ns... ns) const template<typename... Ns> ptrdiff_t idx(Ns... ns) const
{ {
static_assert(ndim==sizeof...(ns), "incorrect number of indices"); static_assert(ndim==sizeof...(ns), "incorrect number of indices");
...@@ -231,7 +233,6 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu ...@@ -231,7 +233,6 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu
using membuf<T>::ptr; using membuf<T>::ptr;
using mav_info<ndim>::shp; using mav_info<ndim>::shp;
using mav_info<ndim>::str; using mav_info<ndim>::str;
using mav_info<ndim>::conformable;
using membuf<T>::rw; using membuf<T>::rw;
using membuf<T>::vraw; using membuf<T>::vraw;
...@@ -305,6 +306,7 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu ...@@ -305,6 +306,7 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu
using mav_info<ndim>::contiguous; using mav_info<ndim>::contiguous;
using mav_info<ndim>::size; using mav_info<ndim>::size;
using mav_info<ndim>::idx; using mav_info<ndim>::idx;
using mav_info<ndim>::conformable;
mav(const T *d_, const shape_t &shp_, const stride_t &str_) mav(const T *d_, const shape_t &shp_, const stride_t &str_)
: mav_info<ndim>(shp_, str_), membuf<T>(d_) {} : mav_info<ndim>(shp_, str_), membuf<T>(d_) {}
......