// interpol_ng.cc
/*
 *  Copyright (C) 2020 Max-Planck-Society
 *  Author: Martin Reinecke
 */

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <vector>
#include <complex>
#include "mr_util/math/constants.h"
#include "mr_util/math/gl_integrator.h"
#include "mr_util/math/es_kernel.h"
#include "mr_util/infra/mav.h"
#include "mr_util/sharp/sharp.h"
#include "mr_util/sharp/sharp_almhelpers.h"
#include "mr_util/sharp/sharp_geomhelpers.h"
#include "alm.h"
#include "mr_util/math/fft.h"
#include "mr_util/bindings/pybind_utils.h"
Martin Reinecke's avatar
Martin Reinecke committed
20
#include <iostream>
21
22
23
24
25
26
27
using namespace std;
using namespace mr;

namespace py = pybind11;

namespace {

Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
28
29
// Minimum oversampling factor requested when sizing the fine (ntheta x nphi)
// grid; it enters the computation of nphi in the Interpolator constructors.
constexpr double ofmin=1.5;

30
31
32
template<typename T> class Interpolator
  {
  protected:
33
    bool adjoint;
Martin Reinecke's avatar
Martin Reinecke committed
34
    size_t lmax, kmax, nphi0, ntheta0, nphi, ntheta;
Martin Reinecke's avatar
Martin Reinecke committed
35
    int nthreads;
Martin Reinecke's avatar
fix    
Martin Reinecke committed
36
37
    double ofactor;
    size_t supp;
38
39
40
    ES_Kernel kernel;
    mav<T,3> cube; // the data cube (theta, phi, 2*mbeam+1[, IQU])

41
    void correct(mav<T,2> &arr, int spin)
42
      {
43
      double sfct = (spin&1) ? -1 : 1;
44
      mav<T,2> tmp({nphi,nphi});
Martin Reinecke's avatar
Martin Reinecke committed
45
      fmav<T> ftmp(tmp);
Martin Reinecke's avatar
Martin Reinecke committed
46
      tmp.apply([](T &v){v=0.;});
Martin Reinecke's avatar
Martin Reinecke committed
47
      auto tmp0=tmp.template subarray<2>({0,0},{nphi0, nphi0});
Martin Reinecke's avatar
Martin Reinecke committed
48
49
50
51
      fmav<T> ftmp0(tmp0);
      for (size_t i=0; i<ntheta0; ++i)
        for (size_t j=0; j<nphi0; ++j)
          tmp0.v(i,j) = arr(i,j);
52
      // extend to second half
Martin Reinecke's avatar
Martin Reinecke committed
53
      for (size_t i=1, i2=nphi0-1; i+1<ntheta0; ++i,--i2)
Martin Reinecke's avatar
Martin Reinecke committed
54
        for (size_t j=0,j2=nphi0/2; j<nphi0; ++j,++j2)
55
          {
Martin Reinecke's avatar
Martin Reinecke committed
56
          if (j2>=nphi0) j2-=nphi0;
Martin Reinecke's avatar
Martin Reinecke committed
57
          tmp0.v(i2,j) = sfct*tmp0(i,j2);
58
          }
Martin Reinecke's avatar
Martin Reinecke committed
59
      // FFT to frequency domain on minimal grid
60
      r2r_fftpack(ftmp0,ftmp0,{1,0},true,true,1./(nphi0*nphi0),nthreads);
Martin Reinecke's avatar
Martin Reinecke committed
61
62
63
64
65
      for (size_t i=0; i<nphi0; ++i)
        {
        tmp0.v(i,nphi0-1)*=0.5;
        tmp0.v(nphi0-1,i)*=0.5;
        }
Martin Reinecke's avatar
Martin Reinecke committed
66
      auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads);
Martin Reinecke's avatar
Martin Reinecke committed
67
68
69
      for (size_t i=0; i<nphi0; ++i)
        for (size_t j=0; j<nphi0; ++j)
          tmp0.v(i,j) *= fct[(i+1)/2] * fct[(j+1)/2];
Martin Reinecke's avatar
Martin Reinecke committed
70
71
72
73
      r2r_fftpack(ftmp,ftmp,{0,1},false,false,1.,nthreads);
      for (size_t i=0; i<ntheta; ++i)
        for (size_t j=0; j<nphi; ++j)
          arr.v(i,j) = tmp(i,j);
74
      }
75
76
77
78
79
80
81
82
83
84
    void decorrect(mav<T,2> &arr, int spin)
      {
      double sfct = (spin&1) ? -1 : 1;
      mav<T,2> tmp({nphi,nphi});
      fmav<T> ftmp(tmp);

      for (size_t i=0; i<ntheta; ++i)
        for (size_t j=0; j<nphi; ++j)
          tmp.v(i,j) = arr(i,j);
      // extend to second half
Martin Reinecke's avatar
Martin Reinecke committed
85
      for (size_t i=1, i2=nphi-1; i+1<ntheta; ++i,--i2)
86
87
88
89
        for (size_t j=0,j2=nphi/2; j<nphi; ++j,++j2)
          {
          if (j2>=nphi) j2-=nphi;
          tmp.v(i2,j) = sfct*tmp(i,j2);
Martin Reinecke's avatar
Martin Reinecke committed
90
//tmp.v(i2,j)=0.;
91
92
93
94
95
96
97
98
99
          }
      r2r_fftpack(ftmp,ftmp,{0,1},true,true,1.,nthreads);
      auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads);
      auto tmp0=tmp.template subarray<2>({0,0},{nphi0, nphi0});
      fmav<T> ftmp0(tmp0);
      for (size_t i=0; i<nphi0; ++i)
        for (size_t j=0; j<nphi0; ++j)
          tmp0.v(i,j) *= fct[(i+1)/2] * fct[(j+1)/2];
      // FFT to (theta, phi) domain on minimal grid
Martin Reinecke's avatar
Martin Reinecke committed
100
101
      r2r_fftpack(ftmp0,ftmp0,{1,0},false, false,1./(nphi0*nphi0),nthreads);
arr.apply([](T &v){v=0.;});
102
103
104
      for (size_t i=0; i<ntheta0; ++i)
        for (size_t j=0; j<nphi0; ++j)
          arr.v(i,j) = tmp0(i,j);
Martin Reinecke's avatar
Martin Reinecke committed
105
106
107
108
// adjoint of the extension to 2pi in theta
//       for (size_t i=1; i+1<ntheta0; ++i)
//         for (size_t j=0; j<nphi0; ++j)
//           arr.v(i,j)*=2;
109
      }
110
111
112

  public:
    Interpolator(const Alm<complex<T>> &slmT, const Alm<complex<T>> &blmT,
Martin Reinecke's avatar
Martin Reinecke committed
113
      double epsilon, int nthreads_)
114
115
      : adjoint(false),
        lmax(slmT.Lmax()),
116
        kmax(blmT.Mmax()),
Martin Reinecke's avatar
Martin Reinecke committed
117
118
        nphi0(2*good_size_real(lmax+1)),
        ntheta0(nphi0/2+1),
119
        nphi(max<size_t>(20,2*good_size_real(size_t((2*lmax+1)*ofmin/2.)))),
Martin Reinecke's avatar
Martin Reinecke committed
120
        ntheta(nphi/2+1),
Martin Reinecke's avatar
Martin Reinecke committed
121
        nthreads(nthreads_),
Martin Reinecke's avatar
Martin Reinecke committed
122
        ofactor(double(nphi)/(2*lmax+1)),
Martin Reinecke's avatar
fix    
Martin Reinecke committed
123
        supp(ES_Kernel::get_supp(epsilon, ofactor)),
Martin Reinecke's avatar
Martin Reinecke committed
124
        kernel(supp, ofactor, nthreads),
125
126
127
128
129
130
        cube({ntheta+2*supp, nphi+2*supp, 2*kmax+1})
      {
      MR_assert((supp<=ntheta) && (supp<=nphi), "support too large!");
      MR_assert(slmT.Mmax()==lmax, "Sky lmax must be equal to Sky mmax");
      MR_assert(blmT.Lmax()==lmax, "Sky and beam lmax must be equal");
      Alm<complex<T>> a1(lmax, lmax), a2(lmax,lmax);
Martin Reinecke's avatar
Martin Reinecke committed
131
      auto ginfo = sharp_make_cc_geom_info(ntheta0,nphi0,0.,cube.stride(1),cube.stride(0));
132
133
134
135
136
137
      auto ainfo = sharp_make_triangular_alm_info(lmax,lmax,1);

      vector<double>lnorm(lmax+1);
      for (size_t i=0; i<=lmax; ++i)
        lnorm[i]=sqrt(4*pi/(2*i+1.));

138
139
140
141
142
143
144
      {
      for (size_t m=0; m<=lmax; ++m)
        for (size_t l=m; l<=lmax; ++l)
          a1(l,m) = slmT(l,m)*blmT(l,0).real()*T(lnorm[l]);
      auto m1 = cube.template subarray<2>({supp,supp,0},{ntheta,nphi,0});
      sharp_alm2map(a1.Alms().data(), m1.vdata(), *ginfo, *ainfo, 0, nthreads);
      correct(m1,0);
145
146
147
//  for (size_t i=1; i+1<ntheta; ++i)
//    for (size_t j=0; j<nphi; ++j)
//      m1.v(i,j)*=2;
148
149
      }
      for (size_t k=1; k<=kmax; ++k)
150
151
152
153
154
155
156
157
        {
        for (size_t m=0; m<=lmax; ++m)
          for (size_t l=m; l<=lmax; ++l)
            {
            if (l<k)
              a1(l,m)=a2(l,m)=0.;
            else
              {
158
              auto tmp = -2.*blmT(l,k)*T(lnorm[l]);
Martin Reinecke's avatar
Martin Reinecke committed
159
              a1(l,m) = slmT(l,m)*tmp.real();
160
              a2(l,m) = slmT(l,m)*tmp.imag();
161
162
              }
            }
163
164
165
166
        auto m1 = cube.template subarray<2>({supp,supp,2*k-1},{ntheta,nphi,0});
        auto m2 = cube.template subarray<2>({supp,supp,2*k  },{ntheta,nphi,0});
        sharp_alm2map_spin(k, a1.Alms().data(), a2.Alms().data(), m1.vdata(),
          m2.vdata(), *ginfo, *ainfo, 0, nthreads);
167
        correct(m1,k);
168
        correct(m2,k);
169
170
171
172
173
174
        }
      // fill border regions
      for (size_t i=0; i<supp; ++i)
        for (size_t j=0, j2=nphi/2; j<nphi; ++j,++j2)
          for (size_t k=0; k<cube.shape(2); ++k)
            {
175
            double fct = (((k+1)/2)&1) ? -1 : 1;
176
            if (j2>=nphi) j2-=nphi;
177
178
            cube.v(supp-1-i,j2+supp,k) = fct*cube(supp+1+i,j+supp,k);
            cube.v(supp+ntheta+i,j2+supp,k) = fct*cube(supp+ntheta-2-i, j+supp,k);
179
180
181
182
183
184
185
186
187
188
            }
      for (size_t i=0; i<ntheta+2*supp; ++i)
        for (size_t j=0; j<supp; ++j)
          for (size_t k=0; k<cube.shape(2); ++k)
            {
            cube.v(i,j,k) = cube(i,j+nphi,k);
            cube.v(i,j+nphi+supp,k) = cube(i,j+supp,k);
            }
      }

189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
    Interpolator(size_t lmax_, size_t kmax_, double epsilon, int nthreads_)
      : adjoint(true),
        lmax(lmax_),
        kmax(kmax_),
        nphi0(2*good_size_real(lmax+1)),
        ntheta0(nphi0/2+1),
        nphi(max<size_t>(20,2*good_size_real(size_t((2*lmax+1)*ofmin/2.)))),
        ntheta(nphi/2+1),
        nthreads(nthreads_),
        ofactor(double(nphi)/(2*lmax+1)),
        supp(ES_Kernel::get_supp(epsilon, ofactor)),
        kernel(supp, ofactor, nthreads),
        cube({ntheta+2*supp, nphi+2*supp, 2*kmax+1})
      {
      MR_assert((supp<=ntheta) && (supp<=nphi), "support too large!");
      cube.apply([](T &v){v=0.;});
      }

207
208
    void interpolx (const mav<T,2> &ptg, mav<T,1> &res) const
      {
209
      MR_assert(!adjoint, "cannot be called in adjoint mode");
210
211
      MR_assert(ptg.shape(0)==res.shape(0), "dimension mismatch");
      MR_assert(ptg.shape(1)==3, "second dimension must have length 3");
Martin Reinecke's avatar
Martin Reinecke committed
212
213
214
      double delta = 2./supp;
      double xdtheta = (ntheta-1)/pi,
             xdphi = nphi/(2*pi);
Martin Reinecke's avatar
Martin Reinecke committed
215
216
      vector<size_t> idx(ptg.shape(0));
      {
217
      // do some pre-sorting to improve cache use
Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
218
219
220
221
222
223
      constexpr size_t cellsize=16;
      size_t nct = ntheta/cellsize+1,
             ncp = nphi/cellsize+1;
      vector<vector<size_t>> mapper(nct*ncp);
      for (size_t i=0; i<ptg.shape(0); ++i)
        {
224
225
        size_t itheta=min(nct-1,size_t(ptg(i,0)/pi*nct)),
               iphi=min(ncp-1,size_t(ptg(i,1)/(2*pi)*ncp));
Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
226
227
        mapper[itheta*ncp+iphi].push_back(i);
        }
Martin Reinecke's avatar
Martin Reinecke committed
228
      size_t cnt=0;
Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
229
      for (const auto &vec: mapper)
Martin Reinecke's avatar
Martin Reinecke committed
230
231
232
233
        for (auto i:vec)
          idx[cnt++] = i;
      }
      execStatic(idx.size(), nthreads, 0, [&](Scheduler &sched)
234
        {
Martin Reinecke's avatar
Martin Reinecke committed
235
236
237
        vector<T> wt(supp), wp(supp);
        vector<T> psiarr(2*kmax+1);
        while (auto rng=sched.getNext()) for(auto ind=rng.lo; ind<rng.hi; ++ind)
238
          {
Martin Reinecke's avatar
Martin Reinecke committed
239
240
241
242
243
244
245
246
247
248
249
          size_t i=idx[ind];
          double f0=0.5*supp+ptg(i,0)*xdtheta;
          size_t i0 = size_t(f0+1.);
          for (size_t t=0; t<supp; ++t)
            wt[t] = kernel((t+i0-f0)*delta - 1);
          double f1=0.5*supp+ptg(i,1)*xdphi;
          size_t i1 = size_t(f1+1.);
          for (size_t t=0; t<supp; ++t)
            wp[t] = kernel((t+i1-f1)*delta - 1);
          double val=0;
          psiarr[0]=1.;
250
251
          double psi=ptg(i,2);
          double cpsi=cos(psi), spsi=sin(psi);
Martin Reinecke's avatar
Martin Reinecke committed
252
253
254
255
          double cnpsi=cpsi, snpsi=spsi;
          for (size_t l=1; l<=kmax; ++l)
            {
            psiarr[2*l-1]=cnpsi;
256
            psiarr[2*l]=snpsi;
Martin Reinecke's avatar
Martin Reinecke committed
257
258
259
260
261
262
263
264
265
            const double tmp = snpsi*cpsi + cnpsi*spsi;
            cnpsi=cnpsi*cpsi - snpsi*spsi;
            snpsi=tmp;
            }
          for (size_t j=0; j<supp; ++j)
            for (size_t k=0; k<supp; ++k)
              for (size_t l=0; l<2*kmax+1; ++l)
                val += cube(i0+j,i1+k,l)*wt[j]*wp[k]*psiarr[l];
          res.v(i) = val;
266
          }
Martin Reinecke's avatar
Martin Reinecke committed
267
        });
268
      }
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338

    void deinterpolx (const mav<T,2> &ptg, const mav<T,1> &data)
      {
      MR_assert(adjoint, "can only be called in adjoint mode");
      MR_assert(ptg.shape(0)==data.shape(0), "dimension mismatch");
      MR_assert(ptg.shape(1)==3, "second dimension must have length 3");
      double delta = 2./supp;
      double xdtheta = (ntheta-1)/pi,
             xdphi = nphi/(2*pi);
      vector<size_t> idx(ptg.shape(0));
      {
      // do some pre-sorting to improve cache use
      constexpr size_t cellsize=16;
      size_t nct = ntheta/cellsize+1,
             ncp = nphi/cellsize+1;
      vector<vector<size_t>> mapper(nct*ncp);
      for (size_t i=0; i<ptg.shape(0); ++i)
        {
        size_t itheta=min(nct-1,size_t(ptg(i,0)/pi*nct)),
               iphi=min(ncp-1,size_t(ptg(i,1)/(2*pi)*ncp));
        mapper[itheta*ncp+iphi].push_back(i);
        }
      size_t cnt=0;
      for (const auto &vec: mapper)
        for (auto i:vec)
          idx[cnt++] = i;
      }
      execStatic(idx.size(), 1, 0, [&](Scheduler &sched) // not parallel yet
        {
        vector<T> wt(supp), wp(supp);
        vector<T> psiarr(2*kmax+1);
        while (auto rng=sched.getNext()) for(auto ind=rng.lo; ind<rng.hi; ++ind)
          {
          size_t i=idx[ind];
          double f0=0.5*supp+ptg(i,0)*xdtheta;
          size_t i0 = size_t(f0+1.);
          for (size_t t=0; t<supp; ++t)
            wt[t] = kernel((t+i0-f0)*delta - 1);
          double f1=0.5*supp+ptg(i,1)*xdphi;
          size_t i1 = size_t(f1+1.);
          for (size_t t=0; t<supp; ++t)
            wp[t] = kernel((t+i1-f1)*delta - 1);
          double val=data(i);
          psiarr[0]=1.;
          double psi=ptg(i,2);
          double cpsi=cos(psi), spsi=sin(psi);
          double cnpsi=cpsi, snpsi=spsi;
          for (size_t l=1; l<=kmax; ++l)
            {
            psiarr[2*l-1]=cnpsi;
            psiarr[2*l]=snpsi;
            const double tmp = snpsi*cpsi + cnpsi*spsi;
            cnpsi=cnpsi*cpsi - snpsi*spsi;
            snpsi=tmp;
            }
          for (size_t j=0; j<supp; ++j)
            for (size_t k=0; k<supp; ++k)
              for (size_t l=0; l<2*kmax+1; ++l)
                cube.v(i0+j,i1+k,l) += val*wt[j]*wp[k]*psiarr[l];
          }
        });
      }
    void getSlmx (const Alm<complex<T>> &blmT, Alm<complex<T>> &slmT)
      {
      MR_assert(adjoint, "can only be called in adjoint mode");
      Alm<complex<T>> a1(lmax, lmax), a2(lmax,lmax);
      auto ginfo = sharp_make_cc_geom_info(ntheta0,nphi0,0.,cube.stride(1),cube.stride(0));
      auto ainfo = sharp_make_triangular_alm_info(lmax,lmax,1);

      // move stuff from border regions onto the main grid
339
      for (size_t i=0; i<cube.shape(0); ++i)
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
        for (size_t j=0; j<supp; ++j)
          for (size_t k=0; k<cube.shape(2); ++k)
            {
            cube.v(i,j+nphi,k) += cube(i,j,k);
            cube.v(i,j+supp,k) += cube(i,j+nphi+supp,k);
            }
      for (size_t i=0; i<supp; ++i)
        for (size_t j=0, j2=nphi/2; j<nphi; ++j,++j2)
          for (size_t k=0; k<cube.shape(2); ++k)
            {
            double fct = (((k+1)/2)&1) ? -1 : 1;
            if (j2>=nphi) j2-=nphi;
            cube.v(supp+1+i,j+supp,k) += fct*cube(supp-1-i,j2+supp,k);
            cube.v(supp+ntheta-2-i, j+supp,k) += fct*cube(supp+ntheta+i,j2+supp,k);
            }
355
356
357
358
359
360
361
362
363
364
365
366
367
368
for (size_t k=0; k<cube.shape(2); ++k)
{
double fct = (((k+1)/2)&1) ? -1 : 1;
for (size_t j=0,j2=nphi/2; j<nphi/2; ++j,++j2)
  {
  if (j2>=nphi) j2-=nphi;
  double tval = (cube(supp,j+supp,k) + fct*cube(supp,j2+supp,k));
  cube.v(supp,j+supp,k) = tval;
  cube.v(supp,j2+supp,k) = fct*tval;
  tval = (cube(supp+ntheta-1,j+supp,k) + fct*cube(supp+ntheta-1,j2+supp,k));
  cube.v(supp+ntheta-1,j+supp,k) = tval;
  cube.v(supp+ntheta-1,j2+supp,k) = fct*tval;
  }
}
369
370
371
372
373
374
375
      vector<double>lnorm(lmax+1);
      for (size_t i=0; i<=lmax; ++i)
        lnorm[i]=sqrt(4*pi/(2*i+1.));

      {
      auto m1 = cube.template subarray<2>({supp,supp,0},{ntheta,nphi,0});
      decorrect(m1,0);
376
377
   for (size_t j=0; j<nphi0; ++j)
     {
378
379
     m1.v(0,j)*=0.5;
     m1.v(ntheta0-1,j)*=0.5;
380
     }
381
382
383
      sharp_alm2map_adjoint(a1.Alms().vdata(), m1.data(), *ginfo, *ainfo, 0, nthreads);
      for (size_t m=0; m<=lmax; ++m)
        for (size_t l=m; l<=lmax; ++l)
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
384
          slmT(l,m)=a1(l,m)*blmT(l,0).real()*T(lnorm[l]);
385
386
387
388
389
390
391
392
      }

      for (size_t k=1; k<=kmax; ++k)
        {
        auto m1 = cube.template subarray<2>({supp,supp,2*k-1},{ntheta,nphi,0});
        auto m2 = cube.template subarray<2>({supp,supp,2*k  },{ntheta,nphi,0});
        decorrect(m1,k);
        decorrect(m2,k);
393
394
   for (size_t j=0; j<nphi0; ++j)
     {
395
396
     m1.v(0,j)*=0.5;
     m1.v(ntheta0-1,j)*=0.5;
397
398
399
     }
   for (size_t j=0; j<nphi0; ++j)
     {
400
401
     m2.v(0,j)*=0.5;
     m2.v(ntheta0-1,j)*=0.5;
402
     }
403
404
405
406
407
408
409
410
411

        sharp_alm2map_spin_adjoint(k, a1.Alms().vdata(), a2.Alms().vdata(), m1.data(),
          m2.data(), *ginfo, *ainfo, 0, nthreads);
        for (size_t m=0; m<=lmax; ++m)
          for (size_t l=m; l<=lmax; ++l)
            {
            if (l>=k)
              {
              auto tmp = -2.*conj(blmT(l,k))*T(lnorm[l]);
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
412
413
              slmT(l,m) += a1(l,m)*tmp.real();
              slmT(l,m) -= a2(l,m)*tmp.imag();
414
415
416
417
              }
            }
        }
      }
418
419
420
421
422
  };

// Thin Python-facing wrapper around Interpolator<T>: converts between
// py::array and mav/Alm objects and exposes a few test helpers and grid-size
// accessors.
template<typename T> class PyInterpolator: public Interpolator<T>
  {
  private:
    // Converts a signed size coming from Python to size_t, rejecting negative
    // values instead of letting them wrap around to huge unsigned numbers.
    static size_t checkedSize(int64_t val)
      {
      MR_assert(val>=0, "lmax and kmax must not be negative");
      return size_t(val);
      }

  public:
    // Forward-mode object built from sky and beam a_lm arrays.
    //   slmT: 1D complex array holding the sky a_lm for (lmax, lmax)
    //   blmT: 1D complex array holding the beam a_lm for (lmax, kmax)
    PyInterpolator(const py::array &slmT, const py::array &blmT,
      int64_t lmax, int64_t kmax, double epsilon, int nthreads=0)
      : Interpolator<T>(
          Alm<complex<T>>(to_mav<complex<T>,1>(slmT),
                          checkedSize(lmax), checkedSize(lmax)),
          Alm<complex<T>>(to_mav<complex<T>,1>(blmT),
                          checkedSize(lmax), checkedSize(kmax)),
          epsilon, nthreads) {}

    // Adjoint-mode object with an empty (zeroed) cube.
    PyInterpolator(int64_t lmax, int64_t kmax, double epsilon, int nthreads=0)
      : Interpolator<T>(checkedSize(lmax), checkedSize(kmax), epsilon,
                        nthreads) {}

    using Interpolator<T>::interpolx;
    using Interpolator<T>::deinterpolx;
    using Interpolator<T>::getSlmx;
    using Interpolator<T>::lmax;
    using Interpolator<T>::kmax;
    using Interpolator<T>::nphi;
    using Interpolator<T>::ntheta;
    using Interpolator<T>::nphi0;
    using Interpolator<T>::ntheta0;
    using Interpolator<T>::correct;
    using Interpolator<T>::decorrect;

    // Evaluates interpolx for the (n,3) pointing array `ptg` and returns the
    // n results as a new 1D array of element type T.
    py::array interpol(const py::array &ptg) const
      {
      auto ptg2 = to_mav<T,2>(ptg);
      auto res = make_Pyarr<T>({ptg2.shape(0)});
      auto res2 = to_mav<T,1>(res,true);
      interpolx(ptg2, res2);
      return res;
      }

    // Feeds the (n,3) pointing array `ptg` and the n samples in `data` to
    // deinterpolx.
    void deinterpol(const py::array &ptg, const py::array &data)
      {
      auto ptg2 = to_mav<T,2>(ptg);
      auto data2 = to_mav<T,1>(data);
      deinterpolx(ptg2, data2);
      }

    // Runs getSlmx with the beam a_lm in `blmT_` and returns the resulting
    // sky a_lm as a new 1D complex array of Num_Alms(lmax, lmax) entries.
    py::array getSlm(const py::array &blmT_)
      {
      auto res = make_Pyarr<complex<T>>({Alm_Base::Num_Alms(lmax, lmax)});
      Alm<complex<T>> blmT(to_mav<complex<T>,1>(blmT_, false), lmax, kmax);
      auto slmT_=to_mav<complex<T>,1>(res, true);
      // start from zero: getSlmx does not write every entry when it only
      // accumulates
      slmT_.apply([](complex<T> &v){v=0;});
      Alm<complex<T>> slmT(slmT_, lmax, lmax);
      getSlmx(blmT, slmT);
      return res;
      }

    // Test helper: copies the (ntheta0, nphi0) input into the corner of a
    // zeroed (ntheta, nphi) array, applies correct() and returns the result.
    py::array test_correct(const py::array &in, int spin)
      {
      auto in2 = to_mav<T,2>(in);
      MR_assert(in2.conformable({ntheta0, nphi0}), "bad input shape");
      auto res = make_Pyarr<T>({ntheta, nphi});
      auto res2 = to_mav<T,2>(res,true);
      res2.apply([](T &v){v=0;});
      for (size_t i=0; i<ntheta0; ++i)
        for (size_t j=0; j<nphi0; ++j)
          res2.v(i,j) = in2(i,j);
      correct (res2, spin);
      return res;
      }

    // Test helper: applies decorrect() to a copy of the (ntheta, nphi) input
    // and returns the leading (ntheta0, nphi0) part of the result.
    py::array test_decorrect(const py::array &in, int spin)
      {
      auto in2 = to_mav<T,2>(in);
      MR_assert(in2.conformable({ntheta, nphi}), "bad input shape");
      auto tmp = mav<T,2>({ntheta, nphi});
      for (size_t i=0; i<ntheta; ++i)
        for (size_t j=0; j<nphi; ++j)
          tmp.v(i,j) = in2(i,j);
      decorrect (tmp, spin);
      auto res = make_Pyarr<T>({ntheta0, nphi0});
      auto res2 = to_mav<T,2>(res,true);
      for (size_t i=0; i<ntheta0; ++i)
        for (size_t j=0; j<nphi0; ++j)
          res2.v(i,j) = tmp(i,j);
      return res;
      }

    // Grid-size accessors (the size_t members are narrowed to int).
    int Nphi0() const { return static_cast<int>(nphi0); }
    int Ntheta0() const { return static_cast<int>(ntheta0); }
    int Nphi() const { return static_cast<int>(nphi); }
    int Ntheta() const { return static_cast<int>(ntheta); }
  };

501
#if 1
Martin Reinecke's avatar
Martin Reinecke committed
502
503
504
// Returns a rotated copy of the a_lm array `alm_`.
//   alm_:            1D complex array, interpreted as a_lm for (lmax, lmax)
//   psi, theta, phi: angles passed unchanged to rotate_alm
// The input array is not modified; the rotation is applied to a fresh copy.
template<typename T> py::array pyrotate_alm(const py::array &alm_, int64_t lmax,
  double psi, double theta, double phi)
  {
  auto src = to_mav<complex<T>,1>(alm_);
  auto alm = make_Pyarr<complex<T>>({src.shape(0)});
  auto dst = to_mav<complex<T>,1>(alm,true);
  for (size_t i=0; i<src.shape(0); ++i) dst.v(i)=src(i);
  // wrap the writable copy as an Alm object and rotate it in place
  auto almobj = Alm<complex<T>>(dst,lmax,lmax);
  rotate_alm(almobj, psi, theta, phi);
  return alm;
  }
#endif

515
516
517
518
519
520
521
} // unnamed namespace

// Python module definition: exposes PyInterpolator<double> and rotate_alm.
PYBIND11_MODULE(interpol_ng, m)
  {
  using namespace pybind11::literals;

  py::class_<PyInterpolator<double>> (m, "PyInterpolator")
    // nthreads defaults to 0, matching the defaults of the C++ constructors
    .def(py::init<const py::array &, const py::array &, int64_t, int64_t, double, int>(),
      "sky"_a, "beam"_a, "lmax"_a, "kmax"_a, "epsilon"_a, "nthreads"_a=0)
    .def(py::init<int64_t, int64_t, double, int>(),
      "lmax"_a, "kmax"_a, "epsilon"_a, "nthreads"_a=0)
    .def ("interpol", &PyInterpolator<double>::interpol, "ptg"_a)
    .def ("deinterpol", &PyInterpolator<double>::deinterpol, "ptg"_a, "data"_a)
    .def ("getSlm", &PyInterpolator<double>::getSlm, "blmT"_a)
    // NOTE(review): "in" is a reserved word in Python, so this argument can
    // only be passed positionally (or via **kwargs); kept for compatibility.
    .def ("test_correct", &PyInterpolator<double>::test_correct, "in"_a, "spin"_a)
    .def ("test_decorrect", &PyInterpolator<double>::test_decorrect, "in"_a, "spin"_a)
    .def ("Nphi", &PyInterpolator<double>::Nphi)
    .def ("Ntheta", &PyInterpolator<double>::Ntheta)
    .def ("Nphi0", &PyInterpolator<double>::Nphi0)
    .def ("Ntheta0", &PyInterpolator<double>::Ntheta0);
#if 1
  m.def("rotate_alm", &pyrotate_alm<double>, "alm"_a, "lmax"_a, "psi"_a, "theta"_a,
    "phi"_a);
#endif
  }