// interpol_ng.cc
/*
 *  Copyright (C) 2020 Max-Planck-Society
 *  Author: Martin Reinecke
 */

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <vector>
#include <complex>
#include "mr_util/math/constants.h"
#include "mr_util/math/gl_integrator.h"
#include "mr_util/math/es_kernel.h"
#include "mr_util/infra/mav.h"
#include "mr_util/sharp/sharp.h"
#include "mr_util/sharp/sharp_almhelpers.h"
#include "mr_util/sharp/sharp_geomhelpers.h"
#include "alm.h"
#include "mr_util/math/fft.h"
#include "mr_util/bindings/pybind_utils.h"
using namespace std;
using namespace mr;

namespace py = pybind11;

namespace {

Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
28
29
constexpr double ofmin=1.5;

30
31
32
template<typename T> class Interpolator
  {
  protected:
33
    bool adjoint;
Martin Reinecke's avatar
Martin Reinecke committed
34
    size_t lmax, kmax, nphi0, ntheta0, nphi, ntheta;
Martin Reinecke's avatar
Martin Reinecke committed
35
    int nthreads;
Martin Reinecke's avatar
fix    
Martin Reinecke committed
36
37
    double ofactor;
    size_t supp;
38
39
40
    ES_Kernel kernel;
    mav<T,3> cube; // the data cube (theta, phi, 2*mbeam+1[, IQU])

41
    void correct(mav<T,2> &arr, int spin)
42
      {
43
      double sfct = (spin&1) ? -1 : 1;
44
      mav<T,2> tmp({nphi,nphi});
Martin Reinecke's avatar
Martin Reinecke committed
45
      tmp.apply([](T &v){v=0.;});
Martin Reinecke's avatar
Martin Reinecke committed
46
      auto tmp0=tmp.template subarray<2>({0,0},{nphi0, nphi0});
Martin Reinecke's avatar
Martin Reinecke committed
47
48
49
50
      fmav<T> ftmp0(tmp0);
      for (size_t i=0; i<ntheta0; ++i)
        for (size_t j=0; j<nphi0; ++j)
          tmp0.v(i,j) = arr(i,j);
51
      // extend to second half
Martin Reinecke's avatar
Martin Reinecke committed
52
      for (size_t i=1, i2=nphi0-1; i+1<ntheta0; ++i,--i2)
Martin Reinecke's avatar
Martin Reinecke committed
53
        for (size_t j=0,j2=nphi0/2; j<nphi0; ++j,++j2)
54
          {
Martin Reinecke's avatar
Martin Reinecke committed
55
          if (j2>=nphi0) j2-=nphi0;
Martin Reinecke's avatar
Martin Reinecke committed
56
          tmp0.v(i2,j) = sfct*tmp0(i,j2);
57
          }
Martin Reinecke's avatar
Martin Reinecke committed
58
      // FFT to frequency domain on minimal grid
Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
59
60
      r2r_fftpack(ftmp0,ftmp0,{0,1},true,true,1./(nphi0*nphi0),nthreads);
      // correct amplitude at Nyquist frequency
Martin Reinecke's avatar
Martin Reinecke committed
61
62
63
64
65
      for (size_t i=0; i<nphi0; ++i)
        {
        tmp0.v(i,nphi0-1)*=0.5;
        tmp0.v(nphi0-1,i)*=0.5;
        }
Martin Reinecke's avatar
Martin Reinecke committed
66
      auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads);
Martin Reinecke's avatar
Martin Reinecke committed
67
68
69
      for (size_t i=0; i<nphi0; ++i)
        for (size_t j=0; j<nphi0; ++j)
          tmp0.v(i,j) *= fct[(i+1)/2] * fct[(j+1)/2];
Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
70
71
72
73
74
75
76
77
78
      auto tmp1=tmp.template subarray<2>({0,0},{nphi, nphi0});
      fmav<T> ftmp1(tmp1);
      // zero-padded FFT in theta direction
      r2r_fftpack(ftmp1,ftmp1,{0},false,false,1.,nthreads);
      auto tmp2=tmp.template subarray<2>({0,0},{ntheta, nphi});
      fmav<T> ftmp2(tmp2);
      fmav<T> farr(arr);
      // zero-padded FFT in phi direction
      r2r_fftpack(ftmp2,farr,{1},false,false,1.,nthreads);
79
      }
80
81
82
83
84
85
86
87
88
89
    void decorrect(mav<T,2> &arr, int spin)
      {
      double sfct = (spin&1) ? -1 : 1;
      mav<T,2> tmp({nphi,nphi});
      fmav<T> ftmp(tmp);

      for (size_t i=0; i<ntheta; ++i)
        for (size_t j=0; j<nphi; ++j)
          tmp.v(i,j) = arr(i,j);
      // extend to second half
Martin Reinecke's avatar
Martin Reinecke committed
90
      for (size_t i=1, i2=nphi-1; i+1<ntheta; ++i,--i2)
91
92
93
94
95
        for (size_t j=0,j2=nphi/2; j<nphi; ++j,++j2)
          {
          if (j2>=nphi) j2-=nphi;
          tmp.v(i2,j) = sfct*tmp(i,j2);
          }
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
96
97
98
99
      r2r_fftpack(ftmp,ftmp,{1},true,true,1.,nthreads);
      auto tmp1=tmp.template subarray<2>({0,0},{nphi, nphi0});
      fmav<T> ftmp1(tmp1);
      r2r_fftpack(ftmp1,ftmp1,{0},true,true,1.,nthreads);
100
101
102
103
104
105
106
      auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads);
      auto tmp0=tmp.template subarray<2>({0,0},{nphi0, nphi0});
      fmav<T> ftmp0(tmp0);
      for (size_t i=0; i<nphi0; ++i)
        for (size_t j=0; j<nphi0; ++j)
          tmp0.v(i,j) *= fct[(i+1)/2] * fct[(j+1)/2];
      // FFT to (theta, phi) domain on minimal grid
Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
107
      r2r_fftpack(ftmp0,ftmp0,{0,1},false, false,1./(nphi0*nphi0),nthreads);
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
108
109
110
111
112
      for (size_t j=0; j<nphi0; ++j)
        {
        tmp0.v(0,j)*=0.5;
        tmp0.v(ntheta0-1,j)*=0.5;
        }
113
114
115
116
      for (size_t i=0; i<ntheta0; ++i)
        for (size_t j=0; j<nphi0; ++j)
          arr.v(i,j) = tmp0(i,j);
      }
117

Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
    vector<size_t> getIdx(const mav<T,2> &ptg) const
      {
      vector<size_t> idx(ptg.shape(0));
      constexpr size_t cellsize=16;
      size_t nct = ntheta/cellsize+1,
             ncp = nphi/cellsize+1;
      vector<vector<size_t>> mapper(nct*ncp);
      for (size_t i=0; i<ptg.shape(0); ++i)
        {
        size_t itheta=min(nct-1,size_t(ptg(i,0)/pi*nct)),
               iphi=min(ncp-1,size_t(ptg(i,1)/(2*pi)*ncp));
        mapper[itheta*ncp+iphi].push_back(i);
        }
      size_t cnt=0;
      for (const auto &vec: mapper)
        for (auto i:vec)
          idx[cnt++] = i;
      return idx;
      }

138
139
  public:
    Interpolator(const Alm<complex<T>> &slmT, const Alm<complex<T>> &blmT,
Martin Reinecke's avatar
Martin Reinecke committed
140
      double epsilon, int nthreads_)
141
142
      : adjoint(false),
        lmax(slmT.Lmax()),
143
        kmax(blmT.Mmax()),
Martin Reinecke's avatar
Martin Reinecke committed
144
145
        nphi0(2*good_size_real(lmax+1)),
        ntheta0(nphi0/2+1),
146
        nphi(max<size_t>(20,2*good_size_real(size_t((2*lmax+1)*ofmin/2.)))),
Martin Reinecke's avatar
Martin Reinecke committed
147
        ntheta(nphi/2+1),
Martin Reinecke's avatar
Martin Reinecke committed
148
        nthreads(nthreads_),
Martin Reinecke's avatar
Martin Reinecke committed
149
        ofactor(double(nphi)/(2*lmax+1)),
Martin Reinecke's avatar
fix    
Martin Reinecke committed
150
        supp(ES_Kernel::get_supp(epsilon, ofactor)),
Martin Reinecke's avatar
Martin Reinecke committed
151
        kernel(supp, ofactor, nthreads),
152
153
154
155
156
157
        cube({ntheta+2*supp, nphi+2*supp, 2*kmax+1})
      {
      MR_assert((supp<=ntheta) && (supp<=nphi), "support too large!");
      MR_assert(slmT.Mmax()==lmax, "Sky lmax must be equal to Sky mmax");
      MR_assert(blmT.Lmax()==lmax, "Sky and beam lmax must be equal");
      Alm<complex<T>> a1(lmax, lmax), a2(lmax,lmax);
Martin Reinecke's avatar
Martin Reinecke committed
158
      auto ginfo = sharp_make_cc_geom_info(ntheta0,nphi0,0.,cube.stride(1),cube.stride(0));
159
160
161
162
163
164
      auto ainfo = sharp_make_triangular_alm_info(lmax,lmax,1);

      vector<double>lnorm(lmax+1);
      for (size_t i=0; i<=lmax; ++i)
        lnorm[i]=sqrt(4*pi/(2*i+1.));

165
166
167
168
169
170
171
172
173
      {
      for (size_t m=0; m<=lmax; ++m)
        for (size_t l=m; l<=lmax; ++l)
          a1(l,m) = slmT(l,m)*blmT(l,0).real()*T(lnorm[l]);
      auto m1 = cube.template subarray<2>({supp,supp,0},{ntheta,nphi,0});
      sharp_alm2map(a1.Alms().data(), m1.vdata(), *ginfo, *ainfo, 0, nthreads);
      correct(m1,0);
      }
      for (size_t k=1; k<=kmax; ++k)
174
175
176
177
178
179
180
181
        {
        for (size_t m=0; m<=lmax; ++m)
          for (size_t l=m; l<=lmax; ++l)
            {
            if (l<k)
              a1(l,m)=a2(l,m)=0.;
            else
              {
182
              auto tmp = -2.*blmT(l,k)*T(lnorm[l]);
Martin Reinecke's avatar
Martin Reinecke committed
183
              a1(l,m) = slmT(l,m)*tmp.real();
184
              a2(l,m) = slmT(l,m)*tmp.imag();
185
186
              }
            }
187
188
189
190
        auto m1 = cube.template subarray<2>({supp,supp,2*k-1},{ntheta,nphi,0});
        auto m2 = cube.template subarray<2>({supp,supp,2*k  },{ntheta,nphi,0});
        sharp_alm2map_spin(k, a1.Alms().data(), a2.Alms().data(), m1.vdata(),
          m2.vdata(), *ginfo, *ainfo, 0, nthreads);
191
        correct(m1,k);
192
        correct(m2,k);
193
194
195
196
197
198
        }
      // fill border regions
      for (size_t i=0; i<supp; ++i)
        for (size_t j=0, j2=nphi/2; j<nphi; ++j,++j2)
          for (size_t k=0; k<cube.shape(2); ++k)
            {
199
            double fct = (((k+1)/2)&1) ? -1 : 1;
200
            if (j2>=nphi) j2-=nphi;
201
202
            cube.v(supp-1-i,j2+supp,k) = fct*cube(supp+1+i,j+supp,k);
            cube.v(supp+ntheta+i,j2+supp,k) = fct*cube(supp+ntheta-2-i, j+supp,k);
203
204
205
206
207
208
209
210
211
212
            }
      for (size_t i=0; i<ntheta+2*supp; ++i)
        for (size_t j=0; j<supp; ++j)
          for (size_t k=0; k<cube.shape(2); ++k)
            {
            cube.v(i,j,k) = cube(i,j+nphi,k);
            cube.v(i,j+nphi+supp,k) = cube(i,j+supp,k);
            }
      }

213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
    Interpolator(size_t lmax_, size_t kmax_, double epsilon, int nthreads_)
      : adjoint(true),
        lmax(lmax_),
        kmax(kmax_),
        nphi0(2*good_size_real(lmax+1)),
        ntheta0(nphi0/2+1),
        nphi(max<size_t>(20,2*good_size_real(size_t((2*lmax+1)*ofmin/2.)))),
        ntheta(nphi/2+1),
        nthreads(nthreads_),
        ofactor(double(nphi)/(2*lmax+1)),
        supp(ES_Kernel::get_supp(epsilon, ofactor)),
        kernel(supp, ofactor, nthreads),
        cube({ntheta+2*supp, nphi+2*supp, 2*kmax+1})
      {
      MR_assert((supp<=ntheta) && (supp<=nphi), "support too large!");
      cube.apply([](T &v){v=0.;});
      }

231
232
    void interpolx (const mav<T,2> &ptg, mav<T,1> &res) const
      {
233
      MR_assert(!adjoint, "cannot be called in adjoint mode");
234
235
      MR_assert(ptg.shape(0)==res.shape(0), "dimension mismatch");
      MR_assert(ptg.shape(1)==3, "second dimension must have length 3");
Martin Reinecke's avatar
Martin Reinecke committed
236
237
238
      double delta = 2./supp;
      double xdtheta = (ntheta-1)/pi,
             xdphi = nphi/(2*pi);
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
239
      auto idx = getIdx(ptg);
Martin Reinecke's avatar
Martin Reinecke committed
240
      execStatic(idx.size(), nthreads, 0, [&](Scheduler &sched)
241
        {
Martin Reinecke's avatar
Martin Reinecke committed
242
243
244
        vector<T> wt(supp), wp(supp);
        vector<T> psiarr(2*kmax+1);
        while (auto rng=sched.getNext()) for(auto ind=rng.lo; ind<rng.hi; ++ind)
245
          {
Martin Reinecke's avatar
Martin Reinecke committed
246
247
248
249
250
251
252
253
254
255
256
          size_t i=idx[ind];
          double f0=0.5*supp+ptg(i,0)*xdtheta;
          size_t i0 = size_t(f0+1.);
          for (size_t t=0; t<supp; ++t)
            wt[t] = kernel((t+i0-f0)*delta - 1);
          double f1=0.5*supp+ptg(i,1)*xdphi;
          size_t i1 = size_t(f1+1.);
          for (size_t t=0; t<supp; ++t)
            wp[t] = kernel((t+i1-f1)*delta - 1);
          double val=0;
          psiarr[0]=1.;
257
258
          double psi=ptg(i,2);
          double cpsi=cos(psi), spsi=sin(psi);
Martin Reinecke's avatar
Martin Reinecke committed
259
260
261
262
          double cnpsi=cpsi, snpsi=spsi;
          for (size_t l=1; l<=kmax; ++l)
            {
            psiarr[2*l-1]=cnpsi;
263
            psiarr[2*l]=snpsi;
Martin Reinecke's avatar
Martin Reinecke committed
264
265
266
267
268
269
270
271
272
            const double tmp = snpsi*cpsi + cnpsi*spsi;
            cnpsi=cnpsi*cpsi - snpsi*spsi;
            snpsi=tmp;
            }
          for (size_t j=0; j<supp; ++j)
            for (size_t k=0; k<supp; ++k)
              for (size_t l=0; l<2*kmax+1; ++l)
                val += cube(i0+j,i1+k,l)*wt[j]*wp[k]*psiarr[l];
          res.v(i) = val;
273
          }
Martin Reinecke's avatar
Martin Reinecke committed
274
        });
275
      }
276
277
278
279
280
281
282
283
284

    void deinterpolx (const mav<T,2> &ptg, const mav<T,1> &data)
      {
      MR_assert(adjoint, "can only be called in adjoint mode");
      MR_assert(ptg.shape(0)==data.shape(0), "dimension mismatch");
      MR_assert(ptg.shape(1)==3, "second dimension must have length 3");
      double delta = 2./supp;
      double xdtheta = (ntheta-1)/pi,
             xdphi = nphi/(2*pi);
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
285
      auto idx = getIdx(ptg);
286
287
288
289
290
291
292

      constexpr size_t cellsize=16;
      size_t nct = ntheta/cellsize+5,
             ncp = nphi/cellsize+5;
      mav<std::mutex,2> locks({nct,ncp});

      execStatic(idx.size(), nthreads, 0, [&](Scheduler &sched)
293
        {
294
        size_t b_theta=99999999999999, b_phi=9999999999999999;
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
        vector<T> wt(supp), wp(supp);
        vector<T> psiarr(2*kmax+1);
        while (auto rng=sched.getNext()) for(auto ind=rng.lo; ind<rng.hi; ++ind)
          {
          size_t i=idx[ind];
          double f0=0.5*supp+ptg(i,0)*xdtheta;
          size_t i0 = size_t(f0+1.);
          for (size_t t=0; t<supp; ++t)
            wt[t] = kernel((t+i0-f0)*delta - 1);
          double f1=0.5*supp+ptg(i,1)*xdphi;
          size_t i1 = size_t(f1+1.);
          for (size_t t=0; t<supp; ++t)
            wp[t] = kernel((t+i1-f1)*delta - 1);
          double val=data(i);
          psiarr[0]=1.;
          double psi=ptg(i,2);
          double cpsi=cos(psi), spsi=sin(psi);
          double cnpsi=cpsi, snpsi=spsi;
          for (size_t l=1; l<=kmax; ++l)
            {
            psiarr[2*l-1]=cnpsi;
            psiarr[2*l]=snpsi;
            const double tmp = snpsi*cpsi + cnpsi*spsi;
            cnpsi=cnpsi*cpsi - snpsi*spsi;
            snpsi=tmp;
            }
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
          size_t b_theta_new = i0/cellsize,
                 b_phi_new = i1/cellsize;
          if ((b_theta_new!=b_theta) || (b_phi_new!=b_phi))
            {
            if (b_theta<locks.shape(0))  // unlock
              {
              locks.v(b_theta,b_phi).unlock();
              locks.v(b_theta,b_phi+1).unlock();
              locks.v(b_theta+1,b_phi).unlock();
              locks.v(b_theta+1,b_phi+1).unlock();
              }
            b_theta = b_theta_new;
            b_phi = b_phi_new;
            locks.v(b_theta,b_phi).lock();
            locks.v(b_theta,b_phi+1).lock();
            locks.v(b_theta+1,b_phi).lock();
            locks.v(b_theta+1,b_phi+1).lock();
            }
339
340
341
342
343
          for (size_t j=0; j<supp; ++j)
            for (size_t k=0; k<supp; ++k)
              for (size_t l=0; l<2*kmax+1; ++l)
                cube.v(i0+j,i1+k,l) += val*wt[j]*wp[k]*psiarr[l];
          }
344
345
346
347
348
349
350
        if (b_theta<locks.shape(0))  // unlock
          {
          locks.v(b_theta,b_phi).unlock();
          locks.v(b_theta,b_phi+1).unlock();
          locks.v(b_theta+1,b_phi).unlock();
          locks.v(b_theta+1,b_phi+1).unlock();
          }
351
352
353
354
355
356
357
358
359
360
        });
      }
    void getSlmx (const Alm<complex<T>> &blmT, Alm<complex<T>> &slmT)
      {
      MR_assert(adjoint, "can only be called in adjoint mode");
      Alm<complex<T>> a1(lmax, lmax), a2(lmax,lmax);
      auto ginfo = sharp_make_cc_geom_info(ntheta0,nphi0,0.,cube.stride(1),cube.stride(0));
      auto ainfo = sharp_make_triangular_alm_info(lmax,lmax,1);

      // move stuff from border regions onto the main grid
361
      for (size_t i=0; i<cube.shape(0); ++i)
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
        for (size_t j=0; j<supp; ++j)
          for (size_t k=0; k<cube.shape(2); ++k)
            {
            cube.v(i,j+nphi,k) += cube(i,j,k);
            cube.v(i,j+supp,k) += cube(i,j+nphi+supp,k);
            }
      for (size_t i=0; i<supp; ++i)
        for (size_t j=0, j2=nphi/2; j<nphi; ++j,++j2)
          for (size_t k=0; k<cube.shape(2); ++k)
            {
            double fct = (((k+1)/2)&1) ? -1 : 1;
            if (j2>=nphi) j2-=nphi;
            cube.v(supp+1+i,j+supp,k) += fct*cube(supp-1-i,j2+supp,k);
            cube.v(supp+ntheta-2-i, j+supp,k) += fct*cube(supp+ntheta+i,j2+supp,k);
            }
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391

      // special treatment for poles
      for (size_t j=0,j2=nphi/2; j<nphi/2; ++j,++j2)
        for (size_t k=0; k<cube.shape(2); ++k)
          {
          double fct = (((k+1)/2)&1) ? -1 : 1;
          if (j2>=nphi) j2-=nphi;
          double tval = (cube(supp,j+supp,k) + fct*cube(supp,j2+supp,k));
          cube.v(supp,j+supp,k) = tval;
          cube.v(supp,j2+supp,k) = fct*tval;
          tval = (cube(supp+ntheta-1,j+supp,k) + fct*cube(supp+ntheta-1,j2+supp,k));
          cube.v(supp+ntheta-1,j+supp,k) = tval;
          cube.v(supp+ntheta-1,j2+supp,k) = fct*tval;
          }

392
393
394
395
396
397
398
399
400
401
      vector<double>lnorm(lmax+1);
      for (size_t i=0; i<=lmax; ++i)
        lnorm[i]=sqrt(4*pi/(2*i+1.));

      {
      auto m1 = cube.template subarray<2>({supp,supp,0},{ntheta,nphi,0});
      decorrect(m1,0);
      sharp_alm2map_adjoint(a1.Alms().vdata(), m1.data(), *ginfo, *ainfo, 0, nthreads);
      for (size_t m=0; m<=lmax; ++m)
        for (size_t l=m; l<=lmax; ++l)
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
402
          slmT(l,m)=conj(a1(l,m))*blmT(l,0).real()*T(lnorm[l]);
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
      }

      for (size_t k=1; k<=kmax; ++k)
        {
        auto m1 = cube.template subarray<2>({supp,supp,2*k-1},{ntheta,nphi,0});
        auto m2 = cube.template subarray<2>({supp,supp,2*k  },{ntheta,nphi,0});
        decorrect(m1,k);
        decorrect(m2,k);

        sharp_alm2map_spin_adjoint(k, a1.Alms().vdata(), a2.Alms().vdata(), m1.data(),
          m2.data(), *ginfo, *ainfo, 0, nthreads);
        for (size_t m=0; m<=lmax; ++m)
          for (size_t l=m; l<=lmax; ++l)
            {
            if (l>=k)
              {
              auto tmp = -2.*conj(blmT(l,k))*T(lnorm[l]);
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
420
421
              slmT(l,m) += conj(a1(l,m))*tmp.real();
              slmT(l,m) -= conj(a2(l,m))*tmp.imag();
422
423
424
425
              }
            }
        }
      }
426
427
428
429
  };

/// Python-facing wrapper around Interpolator<T>: converts numpy arrays to
/// mav/Alm views on the way in and allocates numpy result arrays on the way
/// out.
template<typename T> class PyInterpolator: public Interpolator<T>
  {
  protected:
    using Interpolator<T>::lmax;
    using Interpolator<T>::kmax;
    using Interpolator<T>::interpolx;
    using Interpolator<T>::deinterpolx;
    using Interpolator<T>::getSlmx;

  public:
    /// Forward mode: `slmT` holds sky a_lm for (lmax, lmax), `blmT` holds
    /// beam a_lm for (lmax, kmax).
    PyInterpolator(const py::array &slmT, const py::array &blmT,
      int64_t lmax, int64_t kmax, double epsilon, int nthreads=0)
      : Interpolator<T>(Alm<complex<T>>(to_mav<complex<T>,1>(slmT), lmax, lmax),
                        Alm<complex<T>>(to_mav<complex<T>,1>(blmT), lmax, kmax),
                        epsilon, nthreads) {}
    /// Adjoint mode: starts from an empty data cube.
    PyInterpolator(int64_t lmax, int64_t kmax, double epsilon, int nthreads=0)
      : Interpolator<T>(lmax, kmax, epsilon, nthreads) {}

    /// Returns a new 1D array with one interpolated value per row of `ptg`
    /// (columns theta, phi, psi).
    py::array interpol(const py::array &ptg) const
      {
      auto ptg2 = to_mav<T,2>(ptg);
      // result element type must be T: interpolx() takes a mav<T,1>
      auto res = make_Pyarr<T>({ptg2.shape(0)});
      auto res2 = to_mav<T,1>(res,true);
      interpolx(ptg2, res2);
      return res;
      }

    /// Accumulates `data` (one value per row of `ptg`) into the data cube.
    void deinterpol(const py::array &ptg, const py::array &data)
      {
      auto ptg2 = to_mav<T,2>(ptg);
      auto data2 = to_mav<T,1>(data);
      deinterpolx(ptg2, data2);
      }

    /// Returns a new array of sky a_lm for (lmax, lmax), computed from the
    /// accumulated cube and the beam a_lm `blmT_` for (lmax, kmax).
    py::array getSlm(const py::array &blmT_)
      {
      auto res = make_Pyarr<complex<T>>({Alm_Base::Num_Alms(lmax, lmax)});
      Alm<complex<T>> blmT(to_mav<complex<T>,1>(blmT_, false), lmax, kmax);
      auto slmT_=to_mav<complex<T>,1>(res, true);
      Alm<complex<T>> slmT(slmT_, lmax, lmax);
      getSlmx(blmT, slmT);
      return res;
      }
  };

471
#if 1
Martin Reinecke's avatar
Martin Reinecke committed
472
473
474
template<typename T> py::array pyrotate_alm(const py::array &alm_, int64_t lmax,
  double psi, double theta, double phi)
  {
475
476
477
478
479
480
481
  auto a1 = to_mav<complex<T>,1>(alm_);
  auto alm = make_Pyarr<complex<T>>({a1.shape(0)});
  auto a2 = to_mav<complex<T>,1>(alm,true);
  for (size_t i=0; i<a1.shape(0); ++i) a2.v(i)=a1(i);
  auto blah = Alm<complex<T>>(a2,lmax,lmax);
  rotate_alm(blah, psi, theta, phi);
  return alm;
Martin Reinecke's avatar
Martin Reinecke committed
482
483
484
  }
#endif

485
486
487
488
489
490
491
} // unnamed namespace

// Python module `interpol_ng`: exposes PyInterpolator<double> and
// rotate_alm. `nthreads` defaults to 0 on the Python side as well, matching
// the C++ constructors' default.
PYBIND11_MODULE(interpol_ng, m)
  {
  using namespace pybind11::literals;

  py::class_<PyInterpolator<double>> (m, "PyInterpolator")
    .def(py::init<const py::array &, const py::array &, int64_t, int64_t, double, int>(),
      "sky"_a, "beam"_a, "lmax"_a, "kmax"_a, "epsilon"_a, "nthreads"_a=0)
    .def(py::init<int64_t, int64_t, double, int>(),
      "lmax"_a, "kmax"_a, "epsilon"_a, "nthreads"_a=0)
    .def ("interpol", &PyInterpolator<double>::interpol, "ptg"_a)
    .def ("deinterpol", &PyInterpolator<double>::deinterpol, "ptg"_a, "data"_a)
    .def ("getSlm", &PyInterpolator<double>::getSlm, "blmT"_a);
#if 1
  m.def("rotate_alm", &pyrotate_alm<double>, "alm"_a, "lmax"_a, "psi"_a, "theta"_a,
    "phi"_a);
#endif
  }