Commit 00384bd7 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

halfway there

parent ae4c6c2f
Pipeline #70418 passed with stages
in 12 minutes and 36 seconds
......@@ -4,7 +4,7 @@ import pysharp
import time
import matplotlib.pyplot as plt
np.random.seed(48)
np.random.seed(20)
def nalm(lmax, mmax):
return ((mmax+1)*(mmax+2))//2 + (mmax+1)*(lmax-mmax)
......@@ -32,9 +32,9 @@ def convolve(alm1, alm2, lmax):
job.set_triangular_alm_info(0,0)
return job.map2alm(map)[0]*np.sqrt(4*np.pi)
lmax=10
lmax=40
mmax=lmax
kmax=lmax
kmax=10
# get random sky a_lm
......@@ -42,15 +42,15 @@ kmax=lmax
slmT = random_alm(lmax, mmax)
# build beam a_lm
blmT = random_alm(lmax, mmax)
blmT = random_alm(lmax, kmax)
t0=time.time()
# build interpolator object for slmT and blmT
foo = interpol_ng.PyInterpolator(slmT,blmT,lmax, kmax, epsilon=1e-6, nthreads=2)
foo = interpol_ng.PyInterpolator(slmT,blmT,lmax, kmax, epsilon=1e-6, nthreads=1)
print("setup time: ",time.time()-t0)
nth = 2*lmax+1
nph = 2*mmax+1
#exit()
ptg = np.zeros((nth,nph,3))
ptg[:,:,0] = (np.pi*np.arange(nth)/(nth-1)).reshape((-1,1))
ptg[:,:,1] = (2*np.pi*np.arange(nph)/nph).reshape((1,-1))
......@@ -62,12 +62,24 @@ print("interpolation time: ", time.time()-t0)
plt.subplot(2,2,1)
plt.imshow(bar.reshape((nth,nph)))
bar2 = np.zeros((nth,nph))
blmTfull = np.zeros(slmT.size)+0j
blmTfull[0:blmT.size] = blmT
for ith in range(nth):
for iph in range(nph):
rbeam=interpol_ng.rotate_alm(blmT, lmax, ptg[ith,iph,2],ptg[ith,iph,0],ptg[ith,iph,1])
rbeam=interpol_ng.rotate_alm(blmTfull, lmax, ptg[ith,iph,2],ptg[ith,iph,0],ptg[ith,iph,1])
bar2[ith,iph] = convolve(slmT, rbeam, lmax).real
plt.subplot(2,2,2)
plt.imshow(bar2)
plt.subplot(2,2,3)
plt.imshow((bar2-bar.reshape((nth,nph))))
plt.show()
fake = np.random.uniform(-1.,1., bar.size)
foo2 = interpol_ng.PyInterpolator(lmax, kmax, epsilon=1e-6, nthreads=2)
foo2.deinterpol(ptg.reshape((-1,3)), fake)
bla=foo2.getSlm(blmT)
print(np.vdot(slmT,bla))
slmT[lmax+1:]*=np.sqrt(2)
bla[lmax+1:]*=np.sqrt(2)
print(np.vdot(slmT,bla))
print(np.vdot(fake,bar))
......@@ -17,7 +17,7 @@
#include "alm.h"
#include "mr_util/math/fft.h"
#include "mr_util/bindings/pybind_utils.h"
#include <iostream>
using namespace std;
using namespace mr;
......@@ -42,6 +42,7 @@ template<typename T> class Interpolator
{
double sfct = (spin&1) ? -1 : 1;
mav<T,2> tmp({nphi,nphi});
fmav<T> ftmp(tmp);
tmp.apply([](T &v){v=0.;});
auto tmp0=tmp.template subarray<2>({0,0},{nphi0, nphi0});
fmav<T> ftmp0(tmp0);
......@@ -49,7 +50,7 @@ template<typename T> class Interpolator
for (size_t j=0; j<nphi0; ++j)
tmp0.v(i,j) = arr(i,j);
// extend to second half
for (size_t i=1, i2=2*ntheta0-3; i+1<ntheta0; ++i,--i2)
for (size_t i=1, i2=nphi0-1; i+1<ntheta0; ++i,--i2)
for (size_t j=0,j2=nphi0/2; j<nphi0; ++j,++j2)
{
if (j2>=nphi0) j2-=nphi0;
......@@ -57,19 +58,19 @@ template<typename T> class Interpolator
}
// FFT to frequency domain on minimal grid
r2r_fftpack(ftmp0,ftmp0,{1,0},true,true,1./(nphi0*nphi0),nthreads);
for (size_t i=0; i<nphi0; ++i)
{
tmp0.v(i,nphi0-1)*=0.5;
tmp0.v(nphi0-1,i)*=0.5;
}
auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads);
for (size_t i=0; i<nphi0; ++i)
for (size_t j=0; j<nphi0; ++j)
tmp0.v(i,j) *= fct[(i+1)/2] * fct[(j+1)/2];
auto tmp1=tmp.template subarray<2>({0,0},{nphi, nphi0});
fmav<T> ftmp1(tmp1);
// zero-padded FFT in theta direction
r2r_fftpack(ftmp1,ftmp1,{0},false,false,1.,nthreads);
auto tmp2=tmp.template subarray<2>({0,0},{ntheta, nphi});
fmav<T> ftmp2(tmp2);
fmav<T> farr(arr);
// zero-padded FFT in phi direction
r2r_fftpack(ftmp2,farr,{1},false,false,1.,nthreads);
r2r_fftpack(ftmp,ftmp,{0,1},false,false,1.,nthreads);
for (size_t i=0; i<ntheta; ++i)
for (size_t j=0; j<nphi; ++j)
arr.v(i,j) = tmp(i,j);
}
void decorrect(mav<T,2> &arr, int spin)
{
......@@ -81,7 +82,7 @@ template<typename T> class Interpolator
for (size_t j=0; j<nphi; ++j)
tmp.v(i,j) = arr(i,j);
// extend to second half
for (size_t i=1, i2=2*ntheta-3; i+1<ntheta; ++i,--i2)
for (size_t i=1, i2=nphi-1; i+1<ntheta; ++i,--i2)
for (size_t j=0,j2=nphi/2; j<nphi; ++j,++j2)
{
if (j2>=nphi) j2-=nphi;
......@@ -95,7 +96,8 @@ template<typename T> class Interpolator
for (size_t j=0; j<nphi0; ++j)
tmp0.v(i,j) *= fct[(i+1)/2] * fct[(j+1)/2];
// FFT to (theta, phi) domain on minimal grid
r2r_fftpack(ftmp0,ftmp0,{0,1},false, false,1./(nphi0*nphi0),nthreads);
r2r_fftpack(ftmp0,ftmp0,{1,0},false, false,1./(nphi0*nphi0),nthreads);
arr.apply([](T &v){v=0.;});
for (size_t i=0; i<ntheta0; ++i)
for (size_t j=0; j<nphi0; ++j)
arr.v(i,j) = tmp0(i,j);
......@@ -394,6 +396,12 @@ template<typename T> class PyInterpolator: public Interpolator<T>
using Interpolator<T>::getSlmx;
using Interpolator<T>::lmax;
using Interpolator<T>::kmax;
using Interpolator<T>::nphi;
using Interpolator<T>::ntheta;
using Interpolator<T>::nphi0;
using Interpolator<T>::ntheta0;
using Interpolator<T>::correct;
using Interpolator<T>::decorrect;
py::array interpol(const py::array &ptg) const
{
auto ptg2 = to_mav<T,2>(ptg);
......@@ -419,6 +427,39 @@ slmT_.apply([](complex<T> &v){v=0;});
getSlmx(blmT, slmT);
return res;
}
py::array test_correct(const py::array &in, int spin)
{
auto in2 = to_mav<T,2>(in);
MR_assert(in2.conformable({ntheta0, nphi0}), "bad input shape");
auto res = make_Pyarr<T>({ntheta, nphi});
auto res2 = to_mav<T,2>(res,true);
res2.apply([](T &v){v=0;});
for (size_t i=0; i<ntheta0; ++i)
for (size_t j=0; j<nphi0; ++j)
res2.v(i,j) = in2(i,j);
correct (res2, spin);
return res;
}
py::array test_decorrect(const py::array &in, int spin)
{
auto in2 = to_mav<T,2>(in);
MR_assert(in2.conformable({ntheta, nphi}), "bad input shape");
auto tmp = mav<T,2>({ntheta, nphi});
for (size_t i=0; i<ntheta; ++i)
for (size_t j=0; j<nphi; ++j)
tmp.v(i,j) = in2(i,j);
decorrect (tmp, spin);
auto res = make_Pyarr<T>({ntheta0, nphi0});
auto res2 = to_mav<T,2>(res,true);
for (size_t i=0; i<ntheta0; ++i)
for (size_t j=0; j<nphi0; ++j)
res2.v(i,j) = tmp(i,j);
return res;
}
int Nphi0() const { return nphi0; }
int Ntheta0() const { return ntheta0; }
int Nphi() const { return nphi; }
int Ntheta() const { return ntheta; }
};
#if 1
......@@ -448,7 +489,13 @@ PYBIND11_MODULE(interpol_ng, m)
"lmax"_a, "kmax"_a, "epsilon"_a, "nthreads"_a)
.def ("interpol", &PyInterpolator<double>::interpol, "ptg"_a)
.def ("deinterpol", &PyInterpolator<double>::deinterpol, "ptg"_a, "data"_a)
.def ("getSlm", &PyInterpolator<double>::getSlm, "blmT"_a);
.def ("getSlm", &PyInterpolator<double>::getSlm, "blmT"_a)
.def ("test_correct", &PyInterpolator<double>::test_correct, "in"_a, "spin"_a)
.def ("test_decorrect", &PyInterpolator<double>::test_decorrect, "in"_a, "spin"_a)
.def ("Nphi", &PyInterpolator<double>::Nphi)
.def ("Ntheta", &PyInterpolator<double>::Ntheta)
.def ("Nphi0", &PyInterpolator<double>::Nphi0)
.def ("Ntheta0", &PyInterpolator<double>::Ntheta0);
#if 1
m.def("rotate_alm", &pyrotate_alm<double>, "alm"_a, "lmax"_a, "psi"_a, "theta"_a,
"phi"_a);
......
......@@ -213,6 +213,8 @@ template<size_t ndim> class mav_info
}
bool conformable(const mav_info &other) const
{ return shp==other.shp; }
bool conformable(const shape_t &other) const
{ return shp==other; }
template<typename... Ns> ptrdiff_t idx(Ns... ns) const
{
static_assert(ndim==sizeof...(ns), "incorrect number of indices");
......@@ -231,7 +233,6 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu
using membuf<T>::ptr;
using mav_info<ndim>::shp;
using mav_info<ndim>::str;
using mav_info<ndim>::conformable;
using membuf<T>::rw;
using membuf<T>::vraw;
......@@ -305,6 +306,7 @@ template<typename T, size_t ndim> class mav: public mav_info<ndim>, public membu
using mav_info<ndim>::contiguous;
using mav_info<ndim>::size;
using mav_info<ndim>::idx;
using mav_info<ndim>::conformable;
mav(const T *d_, const shape_t &shp_, const stride_t &str_)
: mav_info<ndim>(shp_, str_), membuf<T>(d_) {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment