Commit 00384bd7 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

halfway there

parent ae4c6c2f
...@@ -4,7 +4,7 @@ import pysharp ...@@ -4,7 +4,7 @@ import pysharp
import time import time
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
np.random.seed(48) np.random.seed(20)
def nalm(lmax, mmax): def nalm(lmax, mmax):
return ((mmax+1)*(mmax+2))//2 + (mmax+1)*(lmax-mmax) return ((mmax+1)*(mmax+2))//2 + (mmax+1)*(lmax-mmax)
...@@ -32,9 +32,9 @@ def convolve(alm1, alm2, lmax): ...@@ -32,9 +32,9 @@ def convolve(alm1, alm2, lmax):
job.set_triangular_alm_info(0,0) job.set_triangular_alm_info(0,0)
return job.map2alm(map)[0]*np.sqrt(4*np.pi) return job.map2alm(map)[0]*np.sqrt(4*np.pi)
lmax=10 lmax=40
mmax=lmax mmax=lmax
kmax=lmax kmax=10
# get random sky a_lm # get random sky a_lm
...@@ -42,15 +42,15 @@ kmax=lmax ...@@ -42,15 +42,15 @@ kmax=lmax
slmT = random_alm(lmax, mmax) slmT = random_alm(lmax, mmax)
# build beam a_lm # build beam a_lm
blmT = random_alm(lmax, mmax) blmT = random_alm(lmax, kmax)
t0=time.time() t0=time.time()
# build interpolator object for slmT and blmT # build interpolator object for slmT and blmT
foo = interpol_ng.PyInterpolator(slmT,blmT,lmax, kmax, epsilon=1e-6, nthreads=2) foo = interpol_ng.PyInterpolator(slmT,blmT,lmax, kmax, epsilon=1e-6, nthreads=1)
print("setup time: ",time.time()-t0) print("setup time: ",time.time()-t0)
nth = 2*lmax+1 nth = 2*lmax+1
nph = 2*mmax+1 nph = 2*mmax+1
#exit()
ptg = np.zeros((nth,nph,3)) ptg = np.zeros((nth,nph,3))
ptg[:,:,0] = (np.pi*np.arange(nth)/(nth-1)).reshape((-1,1)) ptg[:,:,0] = (np.pi*np.arange(nth)/(nth-1)).reshape((-1,1))
ptg[:,:,1] = (2*np.pi*np.arange(nph)/nph).reshape((1,-1)) ptg[:,:,1] = (2*np.pi*np.arange(nph)/nph).reshape((1,-1))
...@@ -62,12 +62,24 @@ print("interpolation time: ", time.time()-t0) ...@@ -62,12 +62,24 @@ print("interpolation time: ", time.time()-t0)
plt.subplot(2,2,1) plt.subplot(2,2,1)
plt.imshow(bar.reshape((nth,nph))) plt.imshow(bar.reshape((nth,nph)))
bar2 = np.zeros((nth,nph)) bar2 = np.zeros((nth,nph))
blmTfull = np.zeros(slmT.size)+0j
blmTfull[0:blmT.size] = blmT
for ith in range(nth): for ith in range(nth):
for iph in range(nph): for iph in range(nph):
rbeam=interpol_ng.rotate_alm(blmT, lmax, ptg[ith,iph,2],ptg[ith,iph,0],ptg[ith,iph,1]) rbeam=interpol_ng.rotate_alm(blmTfull, lmax, ptg[ith,iph,2],ptg[ith,iph,0],ptg[ith,iph,1])
bar2[ith,iph] = convolve(slmT, rbeam, lmax).real bar2[ith,iph] = convolve(slmT, rbeam, lmax).real
plt.subplot(2,2,2) plt.subplot(2,2,2)
plt.imshow(bar2) plt.imshow(bar2)
plt.subplot(2,2,3) plt.subplot(2,2,3)
plt.imshow((bar2-bar.reshape((nth,nph)))) plt.imshow((bar2-bar.reshape((nth,nph))))
plt.show() plt.show()
fake = np.random.uniform(-1.,1., bar.size)
foo2 = interpol_ng.PyInterpolator(lmax, kmax, epsilon=1e-6, nthreads=2)
foo2.deinterpol(ptg.reshape((-1,3)), fake)
bla=foo2.getSlm(blmT)
print(np.vdot(slmT,bla))
slmT[lmax+1:]*=np.sqrt(2)
bla[lmax+1:]*=np.sqrt(2)
print(np.vdot(slmT,bla))
print(np.vdot(fake,bar))
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "alm.h" #include "alm.h"
#include "mr_util/math/fft.h" #include "mr_util/math/fft.h"
#include "mr_util/bindings/pybind_utils.h" #include "mr_util/bindings/pybind_utils.h"
#include <iostream>
using namespace std; using namespace std;
using namespace mr; using namespace mr;
...@@ -42,6 +42,7 @@ template<typename T> class Interpolator ...@@ -42,6 +42,7 @@ template<typename T> class Interpolator
{ {
double sfct = (spin&1) ? -1 : 1; double sfct = (spin&1) ? -1 : 1;
mav<T,2> tmp({nphi,nphi}); mav<T,2> tmp({nphi,nphi});
fmav<T> ftmp(tmp);
tmp.apply([](T &v){v=0.;}); tmp.apply([](T &v){v=0.;});
auto tmp0=tmp.template subarray<2>({0,0},{nphi0, nphi0}); auto tmp0=tmp.template subarray<2>({0,0},{nphi0, nphi0});
fmav<T> ftmp0(tmp0); fmav<T> ftmp0(tmp0);
...@@ -49,7 +50,7 @@ template<typename T> class Interpolator ...@@ -49,7 +50,7 @@ template<typename T> class Interpolator
for (size_t j=0; j<nphi0; ++j) for (size_t j=0; j<nphi0; ++j)
tmp0.v(i,j) = arr(i,j); tmp0.v(i,j) = arr(i,j);
// extend to second half // extend to second half
for (size_t i=1, i2=2*ntheta0-3; i+1<ntheta0; ++i,--i2) for (size_t i=1, i2=nphi0-1; i+1<ntheta0; ++i,--i2)
for (size_t j=0,j2=nphi0/2; j<nphi0; ++j,++j2) for (size_t j=0,j2=nphi0/2; j<nphi0; ++j,++j2)
{ {
if (j2>=nphi0) j2-=nphi0; if (j2>=nphi0) j2-=nphi0;
...@@ -57,19 +58,19 @@ template<typename T> class Interpolator ...@@ -57,19 +58,19 @@ template<typename T> class Interpolator
} }
// FFT to frequency domain on minimal grid // FFT to frequency domain on minimal grid
r2r_fftpack(ftmp0,ftmp0,{1,0},true,true,1./(nphi0*nphi0),nthreads); r2r_fftpack(ftmp0,ftmp0,{1,0},true,true,1./(nphi0*nphi0),nthreads);
for (size_t i=0; i<nphi0; ++i)
{
tmp0.v(i,nphi0-1)*=0.5;
tmp0.v(nphi0-1,i)*=0.5;
}
auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads); auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads);
for (size_t i=0; i<nphi0; ++i) for (size_t i=0; i<nphi0; ++i)
for (size_t j=0; j<nphi0; ++j) for (size_t j=0; j<nphi0; ++j)
tmp0.v(i,j) *= fct[(i+1)/2] * fct[(j+1)/2]; tmp0.v(i,j) *= fct[(i+1)/2] * fct[(j+1)/2];
auto tmp1=tmp.template subarray<2>({0,0},{nphi, nphi0}); r2r_fftpack(ftmp,ftmp,{0,1},false,false,1.,nthreads);
fmav<T> ftmp1(tmp1); for (size_t i=0; i<ntheta; ++i)
// zero-padded FFT in theta direction for (size_t j=0; j<nphi; ++j)
r2r_fftpack(ftmp1,ftmp1,{0},false,false,1.,nthreads); arr.v(i,j) = tmp(i,j);
auto tmp2=tmp.template subarray<2>({0,0},{ntheta, nphi});
fmav<T> ftmp2(tmp2);
fmav<T> farr(arr);
// zero-padded FFT in phi direction
r2r_fftpack(ftmp2,farr,{1},false,false,1.,nthreads);
} }
void decorrect(mav<T,2> &arr, int spin) void decorrect(mav<T,2> &arr, int spin)
{ {
...@@ -81,7 +82,7 @@ template<typename T> class Interpolator ...@@ -81,7 +82,7 @@ template<typename T> class Interpolator
for (size_t j=0; j<nphi; ++j) for (size_t j=0; j<nphi; ++j)
tmp.v(i,j) = arr(i,j); tmp.v(i,j) = arr(i,j);
// extend to second half // extend to second half
for (size_t i=1, i2=2*ntheta-3; i+1<ntheta; ++i,--i2) for (size_t i=1, i2=nphi-1; i+1<ntheta; ++i,--i2)
for (size_t j=0,j2=nphi/2; j<nphi; ++j,++j2) for (size_t j=0,j2=nphi/2; j<nphi; ++j,++j2)
{ {
if (j2>=nphi) j2-=nphi; if (j2>=nphi) j2-=nphi;
...@@ -95,7 +96,8 @@ template<typename T> class Interpolator ...@@ -95,7 +96,8 @@ template<typename T> class Interpolator
for (size_t j=0; j<nphi0; ++j) for (size_t j=0; j<nphi0; ++j)
tmp0.v(i,j) *= fct[(i+1)/2] * fct[(j+1)/2]; tmp0.v(i,j) *= fct[(i+1)/2] * fct[(j+1)/2];
// FFT to (theta, phi) domain on minimal grid // FFT to (theta, phi) domain on minimal grid
r2r_fftpack(ftmp0,ftmp0,{0,1},false, false,1./(nphi0*nphi0),nthreads); r2r_fftpack(ftmp0,ftmp0,{1,0},false, false,1./(nphi0*nphi0),nthreads);
arr.apply([](T &v){v=0.;});
for (size_t i=0; i<ntheta0; ++i) for (size_t i=0; i<ntheta0; ++i)
for (size_t j=0; j<nphi0; ++j) for (size_t j=0; j<nphi0; ++j)
arr.v(i,j) = tmp0(i,j); arr.v(i,j) = tmp0(i,j);
...@@ -394,6 +396,12 @@ template<typename T> class PyInterpolator: public Interpolator<T> ...@@ -394,6 +396,12 @@ template<typename T> class PyInterpolator: public Interpolator<T>
using Interpolator<T>::getSlmx; using Interpolator<T>::getSlmx;
using Interpolator<T>::lmax; using Interpolator<T>::lmax;
using Interpolator<T>::kmax; using Interpolator<T>::kmax;
using Interpolator<T>::nphi;
using Interpolator<T>::ntheta;
using Interpolator<T>::nphi0;
using Interpolator<T>::ntheta0;
using Interpolator<T>::correct;
using Interpolator<T>::decorrect;
py::array interpol(const py::array &ptg) const py::array interpol(const py::array &ptg) const
{ {
auto ptg2 = to_mav<T,2>(ptg); auto ptg2 = to_mav<T,2>(ptg);
...@@ -419,6 +427,39 @@ slmT_.apply([](complex<T> &v){v=0;}); ...@@ -419,6 +427,39 @@ slmT_.apply([](complex<T> &v){v=0;});
getSlmx(blmT, slmT); getSlmx(blmT, slmT);
return res; return res;
} }
py::array test_correct(const py::array &in, int spin)
{
auto in2 = to_mav<T,2>(in);
MR_assert(in2.conformable({ntheta0, nphi0}), "bad input shape");
auto res = make_Pyarr<T>({ntheta, nphi});
auto res2 = to_mav<T,2>(res,true);
res2.apply([](T &v){v=0;});
for (size_t i=0; i<ntheta0; ++i)
for (size_t j=0; j<nphi0; ++j)
res2.v(i,j) = in2(i,j);
correct (res2, spin);
return res;
}
py::array test_decorrect(const py::array &in, int spin)
{
auto in2 = to_mav<T,2>(in);
MR_assert(in2.conformable({ntheta, nphi}), "bad input shape");
auto tmp = mav<T,2>({ntheta, nphi});
for (size_t i=0; i<ntheta; ++i)
for (size_t j=0; j<nphi; ++j)
tmp.v(i,j) = in2(i,j);
decorrect (tmp, spin);
auto res = make_Pyarr<T>({ntheta0, nphi0});
auto res2 = to_mav<T,2>(res,true);
for (size_t i=0; i<ntheta0; ++i)
for (size_t j=0; j<nphi0; ++j)
res2.v(i,j) = tmp(i,j);
return res;
}
int Nphi0() const { return nphi0; }
int Ntheta0() const { return ntheta0; }
int Nphi() const { return nphi; }
int Ntheta() const { return ntheta; }
}; };
#if 1 #if 1
...@@ -448,7 +489,13 @@ PYBIND11_MODULE(interpol_ng, m) ...@@ -448,7 +489,13 @@ PYBIND11_MODULE(interpol_ng, m)
"lmax"_a, "kmax"_a, "epsilon"_a, "nthreads"_a) "lmax"_a, "kmax"_a, "epsilon"_a, "nthreads"_a)
.def ("interpol", &PyInterpolator<double>::interpol, "ptg"_a) .def ("interpol", &PyInterpolator<double>::interpol, "ptg"_a)
.def ("deinterpol", &PyInterpolator<double>::deinterpol, "ptg"_a, "data"_a) .def ("deinterpol", &PyInterpolator<double>::deinterpol, "ptg"_a, "data"_a)
.def ("getSlm", &PyInterpolator<double>::getSlm, "blmT"_a); .def ("getSlm", &PyInterpolator<double>::getSlm, "blmT"_a)
.def ("test_correct", &PyInterpolator<double>::test_correct, "in"_a, "spin"_a)
.def ("test_decorrect", &PyInterpolator<double>::test_decorrect, "in"_a, "spin"_a)
.def ("Nphi", &PyInterpolator<double>::Nphi)
.def ("Ntheta", &PyInterpolator<double>::Ntheta)
.def ("Nphi0", &PyInterpolator<double>::Nphi0)
.def ("Ntheta0", &PyInterpolator<double>::Ntheta0);
#if 1 #if 1
m.def("rotate_alm", &pyrotate_alm<double>, "alm"_a, "lmax"_a, "psi"_a, "theta"_a, m.def("rotate_alm", &pyrotate_alm<double>, "alm"_a, "lmax"_a, "psi"_a, "theta"_a,
"phi"_a); "phi"_a);
......
...@@ -213,6 +213,8 @@ template<size_t ndim> class mav_info ...@@ -213,6 +213,8 @@ template<size_t ndim> class mav_info
} }
bool conformable(const mav_info &other) const bool conformable(const mav_info &other) const
{ return shp==other.shp; } { return shp==other.shp; }
bool conformable(const shape_t &other) const
{ return shp==other; }
template<typename... Ns> ptrdiff_t idx(Ns... ns) const template<typename... Ns> ptrdiff_t idx(Ns... ns) const
{ {
static_assert(ndim==sizeof...(ns), "incorrect number of indices"); static_assert(ndim==sizeof...(ns), "incorrect number of indices");
...@@ -231,7 +233,6 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu ...@@ -231,7 +233,6 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu
using membuf<T>::ptr; using membuf<T>::ptr;
using mav_info<ndim>::shp; using mav_info<ndim>::shp;
using mav_info<ndim>::str; using mav_info<ndim>::str;
using mav_info<ndim>::conformable;
using membuf<T>::rw; using membuf<T>::rw;
using membuf<T>::vraw; using membuf<T>::vraw;
...@@ -305,6 +306,7 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu ...@@ -305,6 +306,7 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu
using mav_info<ndim>::contiguous; using mav_info<ndim>::contiguous;
using mav_info<ndim>::size; using mav_info<ndim>::size;
using mav_info<ndim>::idx; using mav_info<ndim>::idx;
using mav_info<ndim>::conformable;
mav(const T *d_, const shape_t &shp_, const stride_t &str_) mav(const T *d_, const shape_t &shp_, const stride_t &str_)
: mav_info<ndim>(shp_, str_), membuf<T>(d_) {} : mav_info<ndim>(shp_, str_), membuf<T>(d_) {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment