Commit 1f387e75 authored by Martin Reinecke
Browse files

allow instantiation with float

parent 4ae3295b
......@@ -54,8 +54,8 @@ blm = random_alm(lmax, kmax, ncomp)
t0=time.time()
# build interpolator object for slmT and blmT
foo = pyinterpol_ng.PyInterpolator(slm,blm,separate,lmax, kmax, epsilon=1e-6, nthreads=2)
# build interpolator object for slm and blm
foo = pyinterpol_ng.PyInterpolator(slm,blm,separate,lmax, kmax, epsilon=1e-4, nthreads=2)
print("setup time: ",time.time()-t0)
nth = lmax+1
nph = 2*lmax+1
......@@ -98,7 +98,7 @@ t0=time.time()
bar=foo.interpol(ptg)
print("interpolation time: ", time.time()-t0)
fake = np.random.uniform(0.,1., (ptg.shape[0],ncomp2))
foo2 = pyinterpol_ng.PyInterpolator(lmax, kmax, ncomp2, epsilon=1e-6, nthreads=2)
foo2 = pyinterpol_ng.PyInterpolator(lmax, kmax, ncomp2, epsilon=1e-4, nthreads=2)
t0=time.time()
foo2.deinterpol(ptg.reshape((-1,3)), fake)
print("deinterpolation time: ", time.time()-t0)
......
......@@ -34,7 +34,7 @@ template<typename T> class Interpolator
bool adjoint;
size_t lmax, kmax, nphi0, ntheta0, nphi, ntheta;
int nthreads;
double ofactor;
T ofactor;
size_t supp;
ES_Kernel kernel;
size_t ncomp;
......@@ -42,7 +42,7 @@ template<typename T> class Interpolator
void correct(mav<T,2> &arr, int spin)
{
double sfct = (spin&1) ? -1 : 1;
T sfct = (spin&1) ? -1 : 1;
mav<T,2> tmp({nphi,nphi});
tmp.apply([](T &v){v=0.;});
auto tmp0=tmp.template subarray<2>({0,0},{nphi0, nphi0});
......@@ -58,7 +58,7 @@ template<typename T> class Interpolator
tmp0.v(i2,j) = sfct*tmp0(i,j2);
}
// FFT to frequency domain on minimal grid
r2r_fftpack(ftmp0,ftmp0,{0,1},true,true,1./(nphi0*nphi0),nthreads);
r2r_fftpack(ftmp0,ftmp0,{0,1},true,true,T(1./(nphi0*nphi0)),nthreads);
// correct amplitude at Nyquist frequency
for (size_t i=0; i<nphi0; ++i)
{
......@@ -72,16 +72,16 @@ template<typename T> class Interpolator
auto tmp1=tmp.template subarray<2>({0,0},{nphi, nphi0});
fmav<T> ftmp1(tmp1);
// zero-padded FFT in theta direction
r2r_fftpack(ftmp1,ftmp1,{0},false,false,1.,nthreads);
r2r_fftpack(ftmp1,ftmp1,{0},false,false,T(1),nthreads);
auto tmp2=tmp.template subarray<2>({0,0},{ntheta, nphi});
fmav<T> ftmp2(tmp2);
fmav<T> farr(arr);
// zero-padded FFT in phi direction
r2r_fftpack(ftmp2,farr,{1},false,false,1.,nthreads);
r2r_fftpack(ftmp2,farr,{1},false,false,T(1),nthreads);
}
void decorrect(mav<T,2> &arr, int spin)
{
double sfct = (spin&1) ? -1 : 1;
T sfct = (spin&1) ? -1 : 1;
mav<T,2> tmp({nphi,nphi});
fmav<T> ftmp(tmp);
......@@ -95,10 +95,10 @@ template<typename T> class Interpolator
if (j2>=nphi) j2-=nphi;
tmp.v(i2,j) = sfct*tmp(i,j2);
}
r2r_fftpack(ftmp,ftmp,{1},true,true,1.,nthreads);
r2r_fftpack(ftmp,ftmp,{1},true,true,T(1),nthreads);
auto tmp1=tmp.template subarray<2>({0,0},{nphi, nphi0});
fmav<T> ftmp1(tmp1);
r2r_fftpack(ftmp1,ftmp1,{0},true,true,1.,nthreads);
r2r_fftpack(ftmp1,ftmp1,{0},true,true,T(1),nthreads);
auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads);
auto tmp0=tmp.template subarray<2>({0,0},{nphi0, nphi0});
fmav<T> ftmp0(tmp0);
......@@ -106,7 +106,7 @@ template<typename T> class Interpolator
for (size_t j=0; j<nphi0; ++j)
tmp0.v(i,j) *= fct[(i+1)/2] * fct[(j+1)/2];
// FFT to (theta, phi) domain on minimal grid
r2r_fftpack(ftmp0,ftmp0,{0,1},false, false,1./(nphi0*nphi0),nthreads);
r2r_fftpack(ftmp0,ftmp0,{0,1},false, false,T(1./(nphi0*nphi0)),nthreads);
for (size_t j=0; j<nphi0; ++j)
{
tmp0.v(0,j)*=0.5;
......@@ -128,6 +128,7 @@ template<typename T> class Interpolator
{
size_t itheta=min(nct-1,size_t(ptg(i,0)/pi*nct)),
iphi=min(ncp-1,size_t(ptg(i,1)/(2*pi)*ncp));
// MR_assert((itheta<nct)&&(iphi<ncp), "oops");
mapper[itheta*ncp+iphi].push_back(i);
}
size_t cnt=0;
......@@ -140,7 +141,7 @@ template<typename T> class Interpolator
public:
Interpolator(const vector<Alm<complex<T>>> &slm,
const vector<Alm<complex<T>>> &blm,
bool separate, double epsilon, int nthreads_)
bool separate, T epsilon, int nthreads_)
: adjoint(false),
lmax(slm.at(0).Lmax()),
kmax(blm.at(0).Mmax()),
......@@ -149,7 +150,7 @@ template<typename T> class Interpolator
nphi(std::max<size_t>(20,2*good_size_real(size_t((2*lmax+1)*ofmin/2.)))),
ntheta(nphi/2+1),
nthreads(nthreads_),
ofactor(double(nphi)/(2*lmax+1)),
ofactor(T(nphi)/(2*lmax+1)),
supp(ES_Kernel::get_supp(epsilon, ofactor)),
kernel(supp, ofactor, nthreads),
ncomp(separate ? slm.size() : 1),
......@@ -170,7 +171,7 @@ template<typename T> class Interpolator
auto ginfo = sharp_make_cc_geom_info(ntheta0,nphi0,0.,cube.stride(1),cube.stride(0));
auto ainfo = sharp_make_triangular_alm_info(lmax,lmax,1);
vector<double>lnorm(lmax+1);
vector<T>lnorm(lmax+1);
for (size_t i=0; i<=lmax; ++i)
lnorm[i]=std::sqrt(4*pi/(2*i+1.));
......@@ -203,7 +204,7 @@ template<typename T> class Interpolator
{
if (separate)
{
auto tmp = -2.*blm[icomp](l,k)*T(lnorm[l]);
auto tmp = blm[icomp](l,k)*T(-2*lnorm[l]);
a1(l,m) = slm[icomp](l,m)*tmp.real();
a2(l,m) = slm[icomp](l,m)*tmp.imag();
}
......@@ -212,7 +213,7 @@ template<typename T> class Interpolator
a1(l,m) = a2(l,m) = 0;
for (size_t j=0; j<slm.size(); ++j)
{
auto tmp = -2.*blm[j](l,k)*T(lnorm[l]);
auto tmp = blm[j](l,k)*T(-2*lnorm[l]);
a1(l,m) += slm[j](l,m)*tmp.real();
a2(l,m) += slm[j](l,m)*tmp.imag();
}
......@@ -233,7 +234,7 @@ template<typename T> class Interpolator
for (size_t j=0, j2=nphi/2; j<nphi; ++j,++j2)
for (size_t k=0; k<cube.shape(2); ++k)
{
double fct = (((k+1)/2)&1) ? -1 : 1;
T fct = (((k+1)/2)&1) ? -1 : 1;
if (j2>=nphi) j2-=nphi;
for (size_t l=0; l<cube.shape(3); ++l)
{
......@@ -251,7 +252,7 @@ template<typename T> class Interpolator
}
}
Interpolator(size_t lmax_, size_t kmax_, size_t ncomp_, double epsilon, int nthreads_)
Interpolator(size_t lmax_, size_t kmax_, size_t ncomp_, T epsilon, int nthreads_)
: adjoint(true),
lmax(lmax_),
kmax(kmax_),
......@@ -260,7 +261,7 @@ template<typename T> class Interpolator
nphi(std::max<size_t>(20,2*good_size_real(size_t((2*lmax+1)*ofmin/2.)))),
ntheta(nphi/2+1),
nthreads(nthreads_),
ofactor(double(nphi)/(2*lmax+1)),
ofactor(T(nphi)/(2*lmax+1)),
supp(ES_Kernel::get_supp(epsilon, ofactor)),
kernel(supp, ofactor, nthreads),
ncomp(ncomp_),
......@@ -277,9 +278,9 @@ template<typename T> class Interpolator
MR_assert(ptg.shape(0)==res.shape(0), "dimension mismatch");
MR_assert(ptg.shape(1)==3, "second dimension must have length 3");
MR_assert(res.shape(1)==ncomp, "# of components mismatch");
double delta = 2./supp;
double xdtheta = (ntheta-1)/pi,
xdphi = nphi/(2*pi);
T delta = T(2)/supp;
T xdtheta = T((ntheta-1)/pi),
xdphi = T(nphi/(2*pi));
auto idx = getIdx(ptg);
execStatic(idx.size(), nthreads, 0, [&](Scheduler &sched)
{
......@@ -288,11 +289,11 @@ template<typename T> class Interpolator
while (auto rng=sched.getNext()) for(auto ind=rng.lo; ind<rng.hi; ++ind)
{
size_t i=idx[ind];
double f0=0.5*supp+ptg(i,0)*xdtheta;
size_t i0 = size_t(f0+1.);
T f0=T(0.5*supp+ptg(i,0)*xdtheta);
size_t i0 = size_t(f0+T(1));
for (size_t t=0; t<supp; ++t)
wt[t] = kernel((t+i0-f0)*delta - 1);
double f1=0.5*supp+ptg(i,1)*xdphi;
T f1=T(0.5)*supp+ptg(i,1)*xdphi;
size_t i1 = size_t(f1+1.);
for (size_t t=0; t<supp; ++t)
wp[t] = kernel((t+i1-f1)*delta - 1);
......@@ -302,39 +303,33 @@ template<typename T> class Interpolator
double cnpsi=cpsi, snpsi=spsi;
for (size_t l=1; l<=kmax; ++l)
{
psiarr[2*l-1]=cnpsi;
psiarr[2*l]=snpsi;
psiarr[2*l-1]=T(cnpsi);
psiarr[2*l]=T(snpsi);
const double tmp = snpsi*cpsi + cnpsi*spsi;
cnpsi=cnpsi*cpsi - snpsi*spsi;
snpsi=tmp;
}
if (ncomp==1)
{
double vv=0.;
T vv=0;
for (size_t j=0; j<supp; ++j)
for (size_t k=0; k<supp; ++k)
{
double t0=wt[j]*wp[k];
for (size_t l=0; l<2*kmax+1; ++l)
vv += cube(i0+j,i1+k,l,0)*t0*psiarr[l];
}
vv += cube(i0+j,i1+k,l,0)*wt[j]*wp[k]*psiarr[l];
res.v(i,0) = vv;
}
else // ncomp==3
{
double v0=0., v1=0., v2=0.;
T v0=0., v1=0., v2=0.;
for (size_t j=0; j<supp; ++j)
for (size_t k=0; k<supp; ++k)
{
double t0 = wt[j]*wp[k];
for (size_t l=0; l<2*kmax+1; ++l)
{
auto tmp = t0*psiarr[l];
auto tmp = wt[j]*wp[k]*psiarr[l];
v0 += cube(i0+j,i1+k,l,0)*tmp;
v1 += cube(i0+j,i1+k,l,1)*tmp;
v2 += cube(i0+j,i1+k,l,2)*tmp;
}
}
res.v(i,0) = v0;
res.v(i,1) = v1;
res.v(i,2) = v2;
......@@ -349,9 +344,9 @@ template<typename T> class Interpolator
MR_assert(ptg.shape(0)==data.shape(0), "dimension mismatch");
MR_assert(ptg.shape(1)==3, "second dimension must have length 3");
MR_assert(data.shape(1)==ncomp, "# of components mismatch");
double delta = 2./supp;
double xdtheta = (ntheta-1)/pi,
xdphi = nphi/(2*pi);
T delta = T(2)/supp;
T xdtheta = T((ntheta-1)/pi),
xdphi = T(nphi/(2*pi));
auto idx = getIdx(ptg);
constexpr size_t cellsize=16;
......@@ -367,11 +362,11 @@ template<typename T> class Interpolator
while (auto rng=sched.getNext()) for(auto ind=rng.lo; ind<rng.hi; ++ind)
{
size_t i=idx[ind];
double f0=0.5*supp+ptg(i,0)*xdtheta;
T f0=0.5*supp+ptg(i,0)*xdtheta;
size_t i0 = size_t(f0+1.);
for (size_t t=0; t<supp; ++t)
wt[t] = kernel((t+i0-f0)*delta - 1);
double f1=0.5*supp+ptg(i,1)*xdphi;
T f1=0.5*supp+ptg(i,1)*xdphi;
size_t i1 = size_t(f1+1.);
for (size_t t=0; t<supp; ++t)
wp[t] = kernel((t+i1-f1)*delta - 1);
......@@ -381,8 +376,8 @@ template<typename T> class Interpolator
double cnpsi=cpsi, snpsi=spsi;
for (size_t l=1; l<=kmax; ++l)
{
psiarr[2*l-1]=cnpsi;
psiarr[2*l]=snpsi;
psiarr[2*l-1]=T(cnpsi);
psiarr[2*l]=T(snpsi);
const double tmp = snpsi*cpsi + cnpsi*spsi;
cnpsi=cnpsi*cpsi - snpsi*spsi;
snpsi=tmp;
......@@ -407,7 +402,7 @@ template<typename T> class Interpolator
}
if (ncomp==1)
{
double val = data(i,0);
T val = data(i,0);
for (size_t j=0; j<supp; ++j)
for (size_t k=0; k<supp; ++k)
for (size_t l=0; l<2*kmax+1; ++l)
......@@ -415,14 +410,14 @@ template<typename T> class Interpolator
}
else // ncomp==3
{
double v0=data(i,0), v1=data(i,1), v2=data(i,2);
T v0=data(i,0), v1=data(i,1), v2=data(i,2);
for (size_t j=0; j<supp; ++j)
for (size_t k=0; k<supp; ++k)
{
double t0 = wt[j]*wp[k];
T t0 = wt[j]*wp[k];
for (size_t l=0; l<2*kmax+1; ++l)
{
double tmp = t0*psiarr[l];
T tmp = t0*psiarr[l];
cube.v(i0+j,i1+k,l,0) += v0*tmp;
cube.v(i0+j,i1+k,l,1) += v1*tmp;
cube.v(i0+j,i1+k,l,2) += v2*tmp;
......@@ -461,7 +456,7 @@ template<typename T> class Interpolator
for (size_t j=0, j2=nphi/2; j<nphi; ++j,++j2)
for (size_t k=0; k<cube.shape(2); ++k)
{
double fct = (((k+1)/2)&1) ? -1 : 1;
T fct = (((k+1)/2)&1) ? -1 : 1;
if (j2>=nphi) j2-=nphi;
for (size_t l=0; l<cube.shape(3); ++l)
{
......@@ -475,9 +470,9 @@ template<typename T> class Interpolator
for (size_t k=0; k<cube.shape(2); ++k)
for (size_t l=0; l<cube.shape(3); ++l)
{
double fct = (((k+1)/2)&1) ? -1 : 1;
T fct = (((k+1)/2)&1) ? -1 : 1;
if (j2>=nphi) j2-=nphi;
double tval = (cube(supp,j+supp,k,l) + fct*cube(supp,j2+supp,k,l));
T tval = (cube(supp,j+supp,k,l) + fct*cube(supp,j2+supp,k,l));
cube.v(supp,j+supp,k,l) = tval;
cube.v(supp,j2+supp,k,l) = fct*tval;
tval = (cube(supp+ntheta-1,j+supp,k,l) + fct*cube(supp+ntheta-1,j2+supp,k,l));
......@@ -485,7 +480,7 @@ template<typename T> class Interpolator
cube.v(supp+ntheta-1,j2+supp,k,l) = fct*tval;
}
vector<double>lnorm(lmax+1);
vector<T>lnorm(lmax+1);
for (size_t i=0; i<=lmax; ++i)
lnorm[i]=std::sqrt(4*pi/(2*i+1.));
......@@ -522,14 +517,14 @@ template<typename T> class Interpolator
{
if (separate)
{
auto tmp = -2.*conj(blm[icomp](l,k))*T(lnorm[l]);
auto tmp = conj(blm[icomp](l,k))*T(-2*lnorm[l]);
slm[icomp](l,m) += conj(a1(l,m))*tmp.real();
slm[icomp](l,m) -= conj(a2(l,m))*tmp.imag();
}
else
for (size_t j=0; j<blm.size(); ++j)
{
auto tmp = -2.*conj(blm[j](l,k))*T(lnorm[l]);
auto tmp = conj(blm[j](l,k))*T(-2*lnorm[l]);
slm[j](l,m) += conj(a1(l,m))*tmp.real();
slm[j](l,m) -= conj(a2(l,m))*tmp.imag();
}
......
......@@ -14,6 +14,8 @@ namespace py = pybind11;
namespace {
using fptype = float;
template<typename T> class PyInterpolator: public Interpolator<T>
{
protected:
......@@ -43,19 +45,19 @@ void makevec_v(py::array &inp, int64_t lmax, int64_t kmax, vector<Alm<complex<T>
}
public:
PyInterpolator(const py::array &slm, const py::array &blm,
bool separate, int64_t lmax, int64_t kmax, double epsilon, int nthreads=0)
bool separate, int64_t lmax, int64_t kmax, T epsilon, int nthreads=0)
: Interpolator<T>(makevec(slm, lmax, lmax),
makevec(blm, lmax, kmax),
separate, epsilon, nthreads) {}
PyInterpolator(int64_t lmax, int64_t kmax, int64_t ncomp_, double epsilon, int nthreads=0)
PyInterpolator(int64_t lmax, int64_t kmax, int64_t ncomp_, T epsilon, int nthreads=0)
: Interpolator<T>(lmax, kmax, ncomp_, epsilon, nthreads) {}
py::array pyinterpol(const py::array &ptg) const
{
auto ptg2 = to_mav<T,2>(ptg);
auto res = make_Pyarr<double>({ptg2.shape(0),ncomp});
auto res2 = to_mav<double,2>(res,true);
auto res = make_Pyarr<T>({ptg2.shape(0),ncomp});
auto res2 = to_mav<T,2>(res,true);
interpol(ptg2, res2);
return res;
return move(res);
}
void pydeinterpol(const py::array &ptg, const py::array &data)
......@@ -68,10 +70,10 @@ void makevec_v(py::array &inp, int64_t lmax, int64_t kmax, vector<Alm<complex<T>
{
auto blm = makevec(blm_, lmax, kmax);
auto res = make_Pyarr<complex<T>>({Alm_Base::Num_Alms(lmax, lmax),blm.size()});
vector<Alm<complex<double>>> slm;
vector<Alm<complex<T>>> slm;
makevec_v(res, lmax, lmax, slm);
getSlm(blm, slm);
return res;
return move(res);
}
};
......@@ -85,7 +87,7 @@ template<typename T> py::array pyrotate_alm(const py::array &alm_, int64_t lmax,
for (size_t i=0; i<a1.shape(0); ++i) a2.v(i)=a1(i);
auto tmp = Alm<complex<T>>(a2,lmax,lmax);
rotate_alm(tmp, psi, theta, phi);
return alm;
return move(alm);
}
#endif
......@@ -227,17 +229,17 @@ PYBIND11_MODULE(pyinterpol_ng, m)
m.doc() = pyinterpol_ng_DS;
py::class_<PyInterpolator<double>> (m, "PyInterpolator", pyinterpolator_DS)
.def(py::init<const py::array &, const py::array &, bool, int64_t, int64_t, double, int>(),
py::class_<PyInterpolator<fptype>> (m, "PyInterpolator", pyinterpolator_DS)
.def(py::init<const py::array &, const py::array &, bool, int64_t, int64_t, fptype, int>(),
initnormal_DS, "sky"_a, "beam"_a, "separate"_a, "lmax"_a, "kmax"_a, "epsilon"_a,
"nthreads"_a)
.def(py::init<int64_t, int64_t, int64_t, double, int>(), initadjoint_DS,
.def(py::init<int64_t, int64_t, int64_t, fptype, int>(), initadjoint_DS,
"lmax"_a, "kmax"_a, "ncomp"_a, "epsilon"_a, "nthreads"_a)
.def ("interpol", &PyInterpolator<double>::pyinterpol, interpol_DS, "ptg"_a)
.def ("deinterpol", &PyInterpolator<double>::pydeinterpol, deinterpol_DS, "ptg"_a, "data"_a)
.def ("getSlm", &PyInterpolator<double>::pygetSlm, getSlm_DS, "beam"_a);
.def ("interpol", &PyInterpolator<fptype>::pyinterpol, interpol_DS, "ptg"_a)
.def ("deinterpol", &PyInterpolator<fptype>::pydeinterpol, deinterpol_DS, "ptg"_a, "data"_a)
.def ("getSlm", &PyInterpolator<fptype>::pygetSlm, getSlm_DS, "beam"_a);
#if 1
m.def("rotate_alm", &pyrotate_alm<double>, "alm"_a, "lmax"_a, "psi"_a, "theta"_a,
m.def("rotate_alm", &pyrotate_alm<fptype>, "alm"_a, "lmax"_a, "psi"_a, "theta"_a,
"phi"_a);
#endif
}
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment