interpol_ng.h 22.6 KB
Newer Older
1
2
3
4
5
/*
 *  Copyright (C) 2020 Max-Planck-Society
 *  Author: Martin Reinecke
 */

Martin Reinecke's avatar
Martin Reinecke committed
6
7
8
#ifndef MRUTIL_INTERPOL_NG_H
#define MRUTIL_INTERPOL_NG_H

9
10
#include <algorithm>
#include <cmath>
#include <complex>
#include <mutex>
#include <vector>

#include "mr_util/math/constants.h"
#include "mr_util/math/gl_integrator.h"
#include "mr_util/math/es_kernel.h"
#include "mr_util/infra/mav.h"
#include "mr_util/sharp/sharp.h"
#include "mr_util/sharp/sharp_almhelpers.h"
#include "mr_util/sharp/sharp_geomhelpers.h"
#include "alm.h"
#include "mr_util/math/fft.h"
#include "mr_util/bindings/pybind_utils.h"
Martin Reinecke's avatar
Martin Reinecke committed
22

Martin Reinecke's avatar
Martin Reinecke committed
23
namespace mr {
24

Martin Reinecke's avatar
Martin Reinecke committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#if 0
// Disabled draft (excluded by the preprocessor, so never compiled or checked
// by the build): convolution of real data along one axis with a kernel given
// in Fourier space, where the output length along that axis may differ from
// the input length.
namespace detail_fft {

using std::vector;

// Allocates scratch space for one 1D line of `len` elements, or for `vlen`
// lines at once when there are at least `vlen` lines along the other axes
// (so that they can be processed in SIMD batches).
template<typename T, typename T0> aligned_array<T> alloc_tmp_conv
  (const fmav_info &info, size_t axis, size_t len)
  {
  auto othersize = info.size()/info.shape(axis);
  constexpr auto vlen = native_simd<T0>::size();
  auto tmpsize = len*((othersize>=vlen) ? vlen : 1);
  return aligned_array<T>(tmpsize);
  }

// Runs `exec` on every 1D line of `in` along `axis`, writing into `out`, in
// parallel. Lines are first processed in SIMD batches of `vlen`, then the
// remainder one at a time; for single lines the result buffer is `out` itself
// when `allow_inplace` is set and the output stride along `axis` is 1.
// `kernel` must contain l_min/2+1 entries, where l_min is the shorter of the
// input and output lengths along `axis`.
// NOTE(review): in the in-place path `buf` points into an output line of
// length l_out, while the worker first copies the l_in input values into it;
// this looks unsafe when l_in>l_out — confirm before enabling.
template<typename Tplan, typename T, typename T0, typename Exec>
MRUTIL_NOINLINE void general_convolve(const fmav<T> &in, fmav<T> &out,
  const size_t axis, const vector<T0> &kernel, size_t nthreads,
  const Exec &exec, const bool allow_inplace=true)
  {
  std::shared_ptr<Tplan> plan1, plan2;

  size_t l_in=in.shape(axis), l_out=out.shape(axis);
  size_t l_min=std::min(l_in, l_out), l_max=std::max(l_in, l_out);
  MR_assert(kernel.size()==l_min/2+1, "bad kernel size");
  // plan1 transforms input lines, plan2 produces output lines
  plan1 = get_plan<Tplan>(l_in);
  plan2 = get_plan<Tplan>(l_out);

  execParallel(
    util::thread_count(nthreads, in, axis, native_simd<T0>::size()),
    [&](Scheduler &sched) {
      constexpr auto vlen = native_simd<T0>::size();
      auto storage = alloc_tmp_conv<T,T0>(in, axis, l_max); //FIXME!
      multi_iter<vlen> it(in, out, axis, sched.num_threads(), sched.thread_num());
#ifndef MRUTIL_NO_SIMD
      if (vlen>1)
        while (it.remaining()>=vlen)
          {
          it.advance(vlen);
          auto tdatav = reinterpret_cast<add_vec_t<T> *>(storage.data());
          exec(it, in, out, tdatav, *plan1, *plan2, kernel);
          }
#endif
      // remaining lines, one at a time
      while (it.remaining()>0)
        {
        it.advance(1);
        auto buf = allow_inplace && it.stride_out() == 1 ?
          &out.vraw(it.oofs(0)) : reinterpret_cast<T *>(storage.data());
        exec(it, in, out, buf, *plan1, *plan2, kernel);
        }
    });  // end of parallel region
  }

// Per-line worker for real data: copy the input line into `buf`, forward
// real FFT of length l_in, multiply by the kernel, zero-pad up to l_out
// (when l_out>l_in), backward real FFT of length l_out, copy to the output.
// NOTE(review): `kernel` has only l_min/2+1 entries (asserted in
// general_convolve), but the multiplication loop runs i up to l_min-1 and
// accesses kernel[i] and buf[2*i]. For i>l_min/2 this reads past the kernel
// and, when l_in==l_min, past the transformed data. Given the halfcomplex
// layout (r0, r1, i1, r2, i2, ...), the bound should probably be about
// (l_min+1)/2, plus separate handling of an even-length Nyquist term —
// confirm before enabling this code.
struct ExecConvR1
  {
  template <typename T0, typename T, size_t vlen> void operator() (
    const multi_iter<vlen> &it, const fmav<T0> &in, fmav<T0> &out,
    T * buf, const pocketfft_r<T0> &plan1, const pocketfft_r<T0> &plan2,
    const vector<T0> &kernel) const
    {
    size_t l_in = plan1.length(),
           l_out = plan2.length(),
           l_min = std::min(l_in, l_out);
    copy_input(it, in, buf);
    plan1.exec(buf, T0(1), true);
    buf[0] *= kernel[0];
    for (size_t i=1; i<l_min; ++i)
      { buf[2*i-1]*=kernel[i]; buf[2*i] *=kernel[i]; }
    for (size_t i=l_in; i<l_out; ++i) buf[i] = T(0);
    plan2.exec(buf, T0(1), false);
    copy_output(it, buf, out);
    }
  };

// Convolves `in` along `axis` into `out` using ExecConvR1. All other axes
// must have the same extent in `in` and `out`; if both share the same data
// pointer, their strides must be identical as well. Empty input is a no-op.
template<typename T> void convolve_1d(const fmav<T> &in,
  fmav<T> &out, size_t axis, const vector<T> &kernel, size_t nthreads=1)
  {
//  util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
  MR_assert(axis<in.ndim(), "bad axis number");
  MR_assert(in.ndim()==out.ndim(), "dimensionality mismatch");
  if (in.data()==out.data())
    MR_assert(in.strides()==out.strides(), "strides mismatch");
  for (size_t i=0; i<in.ndim(); ++i)
    if (i!=axis)
      MR_assert(in.shape(i)==out.shape(i), "shape mismatch");
  if (in.size()==0) return;
  general_convolve<pocketfft_r<T>>(in, out, axis, kernel, nthreads,
    ExecConvR1());
  }

}
#endif
Martin Reinecke's avatar
Martin Reinecke committed
116
namespace detail_interpol_ng {
117

Martin Reinecke's avatar
Martin Reinecke committed
118
using namespace std;
119
120
121
122

template<typename T> class Interpolator
  {
  protected:
123
    bool adjoint;
Martin Reinecke's avatar
Martin Reinecke committed
124
    size_t lmax, kmax, nphi0, ntheta0, nphi, ntheta;
Martin Reinecke's avatar
Martin Reinecke committed
125
    int nthreads;
126
    T ofactor;
Martin Reinecke's avatar
fix    
Martin Reinecke committed
127
    size_t supp;
128
    ES_Kernel kernel;
129
130
    size_t ncomp;
    mav<T,4> cube; // the data cube (theta, phi, 2*mbeam+1, TGC)
131

132
    void correct(mav<T,2> &arr, int spin)
133
      {
134
      T sfct = (spin&1) ? -1 : 1;
135
      mav<T,2> tmp({nphi,nphi});
Martin Reinecke's avatar
Martin Reinecke committed
136
      tmp.apply([](T &v){v=0.;});
Martin Reinecke's avatar
Martin Reinecke committed
137
      auto tmp0=tmp.template subarray<2>({0,0},{nphi0, nphi0});
Martin Reinecke's avatar
Martin Reinecke committed
138
139
140
141
      fmav<T> ftmp0(tmp0);
      for (size_t i=0; i<ntheta0; ++i)
        for (size_t j=0; j<nphi0; ++j)
          tmp0.v(i,j) = arr(i,j);
142
      // extend to second half
Martin Reinecke's avatar
Martin Reinecke committed
143
// FIXME: merge with loop above to avoid edundant memory reads.
Martin Reinecke's avatar
Martin Reinecke committed
144
      for (size_t i=1, i2=nphi0-1; i+1<ntheta0; ++i,--i2)
Martin Reinecke's avatar
Martin Reinecke committed
145
        for (size_t j=0,j2=nphi0/2; j<nphi0; ++j,++j2)
146
          {
Martin Reinecke's avatar
Martin Reinecke committed
147
          if (j2>=nphi0) j2-=nphi0;
Martin Reinecke's avatar
Martin Reinecke committed
148
          tmp0.v(i2,j) = sfct*tmp0(i,j2);
149
          }
Martin Reinecke's avatar
Martin Reinecke committed
150
      // FFT to frequency domain on minimal grid
Martin Reinecke's avatar
Martin Reinecke committed
151
// one bad FFT axis
152
      r2r_fftpack(ftmp0,ftmp0,{0,1},true,true,T(1./(nphi0*nphi0)),nthreads);
Martin Reinecke's avatar
Martin Reinecke committed
153

Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
154
      // correct amplitude at Nyquist frequency
Martin Reinecke's avatar
Martin Reinecke committed
155
156
157
158
159
      for (size_t i=0; i<nphi0; ++i)
        {
        tmp0.v(i,nphi0-1)*=0.5;
        tmp0.v(nphi0-1,i)*=0.5;
        }
Martin Reinecke's avatar
Martin Reinecke committed
160
      auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads);
Martin Reinecke's avatar
Martin Reinecke committed
161
162
163
      for (size_t i=0; i<nphi0; ++i)
        for (size_t j=0; j<nphi0; ++j)
          tmp0.v(i,j) *= fct[(i+1)/2] * fct[(j+1)/2];
Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
164
165
166
      auto tmp1=tmp.template subarray<2>({0,0},{nphi, nphi0});
      fmav<T> ftmp1(tmp1);
      // zero-padded FFT in theta direction
Martin Reinecke's avatar
Martin Reinecke committed
167
// one bad FFT axis
168
      r2r_fftpack(ftmp1,ftmp1,{0},false,false,T(1),nthreads);
Martin Reinecke's avatar
Martin Reinecke committed
169

Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
170
171
172
173
      auto tmp2=tmp.template subarray<2>({0,0},{ntheta, nphi});
      fmav<T> ftmp2(tmp2);
      fmav<T> farr(arr);
      // zero-padded FFT in phi direction
174
      r2r_fftpack(ftmp2,farr,{1},false,false,T(1),nthreads);
175
      }
176
177
    void decorrect(mav<T,2> &arr, int spin)
      {
178
      T sfct = (spin&1) ? -1 : 1;
179
180
181
182
183
184
185
      mav<T,2> tmp({nphi,nphi});
      fmav<T> ftmp(tmp);

      for (size_t i=0; i<ntheta; ++i)
        for (size_t j=0; j<nphi; ++j)
          tmp.v(i,j) = arr(i,j);
      // extend to second half
Martin Reinecke's avatar
Martin Reinecke committed
186
      for (size_t i=1, i2=nphi-1; i+1<ntheta; ++i,--i2)
187
188
189
190
191
        for (size_t j=0,j2=nphi/2; j<nphi; ++j,++j2)
          {
          if (j2>=nphi) j2-=nphi;
          tmp.v(i2,j) = sfct*tmp(i,j2);
          }
192
      r2r_fftpack(ftmp,ftmp,{1},true,true,T(1),nthreads);
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
193
194
      auto tmp1=tmp.template subarray<2>({0,0},{nphi, nphi0});
      fmav<T> ftmp1(tmp1);
195
      r2r_fftpack(ftmp1,ftmp1,{0},true,true,T(1),nthreads);
196
197
198
199
200
201
202
      auto fct = kernel.correction_factors(nphi, nphi0/2+1, nthreads);
      auto tmp0=tmp.template subarray<2>({0,0},{nphi0, nphi0});
      fmav<T> ftmp0(tmp0);
      for (size_t i=0; i<nphi0; ++i)
        for (size_t j=0; j<nphi0; ++j)
          tmp0.v(i,j) *= fct[(i+1)/2] * fct[(j+1)/2];
      // FFT to (theta, phi) domain on minimal grid
203
      r2r_fftpack(ftmp0,ftmp0,{0,1},false, false,T(1./(nphi0*nphi0)),nthreads);
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
204
205
206
207
208
      for (size_t j=0; j<nphi0; ++j)
        {
        tmp0.v(0,j)*=0.5;
        tmp0.v(ntheta0-1,j)*=0.5;
        }
209
210
211
212
      for (size_t i=0; i<ntheta0; ++i)
        for (size_t j=0; j<nphi0; ++j)
          arr.v(i,j) = tmp0(i,j);
      }
213

Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
214
215
216
217
218
219
220
221
222
223
224
    vector<size_t> getIdx(const mav<T,2> &ptg) const
      {
      vector<size_t> idx(ptg.shape(0));
      constexpr size_t cellsize=16;
      size_t nct = ntheta/cellsize+1,
             ncp = nphi/cellsize+1;
      vector<vector<size_t>> mapper(nct*ncp);
      for (size_t i=0; i<ptg.shape(0); ++i)
        {
        size_t itheta=min(nct-1,size_t(ptg(i,0)/pi*nct)),
               iphi=min(ncp-1,size_t(ptg(i,1)/(2*pi)*ncp));
225
//        MR_assert((itheta<nct)&&(iphi<ncp), "oops");
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
226
227
228
229
230
231
232
233
234
        mapper[itheta*ncp+iphi].push_back(i);
        }
      size_t cnt=0;
      for (const auto &vec: mapper)
        for (auto i:vec)
          idx[cnt++] = i;
      return idx;
      }

235
  public:
236
237
    Interpolator(const vector<Alm<complex<T>>> &slm,
                 const vector<Alm<complex<T>>> &blm,
238
                 bool separate, T epsilon, T ofmin, int nthreads_)
239
      : adjoint(false),
240
241
        lmax(slm.at(0).Lmax()),
        kmax(blm.at(0).Mmax()),
Martin Reinecke's avatar
Martin Reinecke committed
242
243
        nphi0(2*good_size_real(lmax+1)),
        ntheta0(nphi0/2+1),
Martin Reinecke's avatar
Martin Reinecke committed
244
        nphi(std::max<size_t>(20,2*good_size_real(size_t((2*lmax+1)*ofmin/2.)))),
Martin Reinecke's avatar
Martin Reinecke committed
245
        ntheta(nphi/2+1),
Martin Reinecke's avatar
Martin Reinecke committed
246
        nthreads(nthreads_),
247
        ofactor(T(nphi)/(2*lmax+1)),
Martin Reinecke's avatar
fix    
Martin Reinecke committed
248
        supp(ES_Kernel::get_supp(epsilon, ofactor)),
Martin Reinecke's avatar
Martin Reinecke committed
249
        kernel(supp, ofactor, nthreads),
250
251
        ncomp(separate ? slm.size() : 1),
        cube({ntheta+2*supp, nphi+2*supp, 2*kmax+1, ncomp})
252
      {
253
      MR_assert((ncomp==1)||(ncomp==3), "currently only 1 or 3 components allowed");
254
255
256
257
258
259
260
261
262
      MR_assert(slm.size()==blm.size(), "inconsistent slm and blm vectors");
      for (size_t i=0; i<slm.size(); ++i)
        {
        MR_assert(slm[i].Lmax()==lmax, "inconsistent Sky lmax");
        MR_assert(slm[i].Mmax()==lmax, "Sky lmax must be equal to Sky mmax");
        MR_assert(blm[i].Lmax()==lmax, "Sky and beam lmax must be equal");
        MR_assert(blm[i].Mmax()==kmax, "Inconcistent beam mmax");
        }

263
264
      MR_assert((supp<=ntheta) && (supp<=nphi), "support too large!");
      Alm<complex<T>> a1(lmax, lmax), a2(lmax,lmax);
Martin Reinecke's avatar
Martin Reinecke committed
265
      auto ginfo = sharp_make_cc_geom_info(ntheta0,nphi0,0.,cube.stride(1),cube.stride(0));
266
267
      auto ainfo = sharp_make_triangular_alm_info(lmax,lmax,1);

268
      vector<T>lnorm(lmax+1);
269
      for (size_t i=0; i<=lmax; ++i)
Martin Reinecke's avatar
Martin Reinecke committed
270
        lnorm[i]=std::sqrt(4*pi/(2*i+1.));
271

272
      for (size_t icomp=0; icomp<ncomp; ++icomp)
273
274
275
276
        {
        for (size_t m=0; m<=lmax; ++m)
          for (size_t l=m; l<=lmax; ++l)
            {
277
278
            if (separate)
              a1(l,m) = slm[icomp](l,m)*blm[icomp](l,0).real()*T(lnorm[l]);
279
280
            else
              {
281
282
283
              a1(l,m) = 0;
              for (size_t j=0; j<slm.size(); ++j)
                a1(l,m) += slm[j](l,m)*blm[j](l,0).real()*T(lnorm[l]);
284
285
              }
            }
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
        auto m1 = cube.template subarray<2>({supp,supp,0,icomp},{ntheta,nphi,0,0});
        sharp_alm2map(a1.Alms().data(), m1.vdata(), *ginfo, *ainfo, 0, nthreads);
        correct(m1,0);

        for (size_t k=1; k<=kmax; ++k)
          {
          for (size_t m=0; m<=lmax; ++m)
            for (size_t l=m; l<=lmax; ++l)
              {
              if (l<k)
                a1(l,m)=a2(l,m)=0.;
              else
                {
                if (separate)
                  {
301
                  auto tmp = blm[icomp](l,k)*T(-2*lnorm[l]);
302
303
304
305
306
307
308
309
                  a1(l,m) = slm[icomp](l,m)*tmp.real();
                  a2(l,m) = slm[icomp](l,m)*tmp.imag();
                  }
                else
                  {
                  a1(l,m) = a2(l,m) = 0;
                  for (size_t j=0; j<slm.size(); ++j)
                    {
310
                    auto tmp = blm[j](l,k)*T(-2*lnorm[l]);
311
312
313
314
                    a1(l,m) += slm[j](l,m)*tmp.real();
                    a2(l,m) += slm[j](l,m)*tmp.imag();
                    }
                  }
Martin Reinecke's avatar
fixes    
Martin Reinecke committed
315
                }
316
              }
Martin Reinecke's avatar
fixes    
Martin Reinecke committed
317
318
319
320
321
322
          auto m1 = cube.template subarray<2>({supp,supp,2*k-1,icomp},{ntheta,nphi,0,0});
          auto m2 = cube.template subarray<2>({supp,supp,2*k  ,icomp},{ntheta,nphi,0,0});
          sharp_alm2map_spin(k, a1.Alms().data(), a2.Alms().data(), m1.vdata(),
            m2.vdata(), *ginfo, *ainfo, 0, nthreads);
          correct(m1,k);
          correct(m2,k);
323
          }
324
        }
325

326
327
328
329
330
      // fill border regions
      for (size_t i=0; i<supp; ++i)
        for (size_t j=0, j2=nphi/2; j<nphi; ++j,++j2)
          for (size_t k=0; k<cube.shape(2); ++k)
            {
331
            T fct = (((k+1)/2)&1) ? -1 : 1;
332
            if (j2>=nphi) j2-=nphi;
333
334
335
336
337
            for (size_t l=0; l<cube.shape(3); ++l)
              {
              cube.v(supp-1-i,j2+supp,k,l) = fct*cube(supp+1+i,j+supp,k,l);
              cube.v(supp+ntheta+i,j2+supp,k,l) = fct*cube(supp+ntheta-2-i,j+supp,k,l);
              }
338
339
340
341
            }
      for (size_t i=0; i<ntheta+2*supp; ++i)
        for (size_t j=0; j<supp; ++j)
          for (size_t k=0; k<cube.shape(2); ++k)
342
            for (size_t l=0; l<cube.shape(3); ++l)
343
            {
344
345
            cube.v(i,j,k,l) = cube(i,j+nphi,k,l);
            cube.v(i,j+nphi+supp,k,l) = cube(i,j+supp,k,l);
346
347
348
            }
      }

349
    Interpolator(size_t lmax_, size_t kmax_, size_t ncomp_, T epsilon, T ofmin, int nthreads_)
350
351
352
353
354
      : adjoint(true),
        lmax(lmax_),
        kmax(kmax_),
        nphi0(2*good_size_real(lmax+1)),
        ntheta0(nphi0/2+1),
Martin Reinecke's avatar
Martin Reinecke committed
355
        nphi(std::max<size_t>(20,2*good_size_real(size_t((2*lmax+1)*ofmin/2.)))),
356
357
        ntheta(nphi/2+1),
        nthreads(nthreads_),
358
        ofactor(T(nphi)/(2*lmax+1)),
359
360
        supp(ES_Kernel::get_supp(epsilon, ofactor)),
        kernel(supp, ofactor, nthreads),
361
362
        ncomp(ncomp_),
        cube({ntheta+2*supp, nphi+2*supp, 2*kmax+1, ncomp_})
363
      {
364
      MR_assert((ncomp==1)||(ncomp==3), "currently only 1 or 3 components allowed");
365
366
367
368
      MR_assert((supp<=ntheta) && (supp<=nphi), "support too large!");
      cube.apply([](T &v){v=0.;});
      }

369
    void interpol (const mav<T,2> &ptg, mav<T,2> &res) const
370
      {
371
      MR_assert(!adjoint, "cannot be called in adjoint mode");
372
373
      MR_assert(ptg.shape(0)==res.shape(0), "dimension mismatch");
      MR_assert(ptg.shape(1)==3, "second dimension must have length 3");
374
      MR_assert(res.shape(1)==ncomp, "# of components mismatch");
375
376
377
      T delta = T(2)/supp;
      T xdtheta = T((ntheta-1)/pi),
        xdphi = T(nphi/(2*pi));
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
378
      auto idx = getIdx(ptg);
Martin Reinecke's avatar
Martin Reinecke committed
379
      execStatic(idx.size(), nthreads, 0, [&](Scheduler &sched)
380
        {
Martin Reinecke's avatar
Martin Reinecke committed
381
382
383
        vector<T> wt(supp), wp(supp);
        vector<T> psiarr(2*kmax+1);
        while (auto rng=sched.getNext()) for(auto ind=rng.lo; ind<rng.hi; ++ind)
384
          {
Martin Reinecke's avatar
Martin Reinecke committed
385
          size_t i=idx[ind];
386
387
          T f0=T(0.5*supp+ptg(i,0)*xdtheta);
          size_t i0 = size_t(f0+T(1));
Martin Reinecke's avatar
Martin Reinecke committed
388
389
          for (size_t t=0; t<supp; ++t)
            wt[t] = kernel((t+i0-f0)*delta - 1);
390
          T f1=T(0.5)*supp+ptg(i,1)*xdphi;
Martin Reinecke's avatar
Martin Reinecke committed
391
392
393
394
          size_t i1 = size_t(f1+1.);
          for (size_t t=0; t<supp; ++t)
            wp[t] = kernel((t+i1-f1)*delta - 1);
          psiarr[0]=1.;
395
396
          double psi=ptg(i,2);
          double cpsi=cos(psi), spsi=sin(psi);
Martin Reinecke's avatar
Martin Reinecke committed
397
398
399
          double cnpsi=cpsi, snpsi=spsi;
          for (size_t l=1; l<=kmax; ++l)
            {
400
401
            psiarr[2*l-1]=T(cnpsi);
            psiarr[2*l]=T(snpsi);
Martin Reinecke's avatar
Martin Reinecke committed
402
403
404
405
            const double tmp = snpsi*cpsi + cnpsi*spsi;
            cnpsi=cnpsi*cpsi - snpsi*spsi;
            snpsi=tmp;
            }
406
407
          if (ncomp==1)
            {
408
            T vv=0;
409
410
411
            for (size_t j=0; j<supp; ++j)
              for (size_t k=0; k<supp; ++k)
                for (size_t l=0; l<2*kmax+1; ++l)
412
                  vv += cube(i0+j,i1+k,l,0)*wt[j]*wp[k]*psiarr[l];
413
414
415
416
            res.v(i,0) = vv;
            }
          else // ncomp==3
            {
417
            T v0=0., v1=0., v2=0.;
418
419
420
421
            for (size_t j=0; j<supp; ++j)
              for (size_t k=0; k<supp; ++k)
                for (size_t l=0; l<2*kmax+1; ++l)
                  {
422
                  auto tmp = wt[j]*wp[k]*psiarr[l];
423
424
425
426
427
428
429
430
                  v0 += cube(i0+j,i1+k,l,0)*tmp;
                  v1 += cube(i0+j,i1+k,l,1)*tmp;
                  v2 += cube(i0+j,i1+k,l,2)*tmp;
                  }
            res.v(i,0) = v0;
            res.v(i,1) = v1;
            res.v(i,2) = v2;
            }
431
          }
Martin Reinecke's avatar
Martin Reinecke committed
432
        });
433
      }
434

435
436
437
    size_t support() const
      { return supp; }

438
    void deinterpol (const mav<T,2> &ptg, const mav<T,2> &data)
439
440
441
442
      {
      MR_assert(adjoint, "can only be called in adjoint mode");
      MR_assert(ptg.shape(0)==data.shape(0), "dimension mismatch");
      MR_assert(ptg.shape(1)==3, "second dimension must have length 3");
443
      MR_assert(data.shape(1)==ncomp, "# of components mismatch");
444
445
446
      T delta = T(2)/supp;
      T xdtheta = T((ntheta-1)/pi),
        xdphi = T(nphi/(2*pi));
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
447
      auto idx = getIdx(ptg);
448
449
450
451
452
453
454

      constexpr size_t cellsize=16;
      size_t nct = ntheta/cellsize+5,
             ncp = nphi/cellsize+5;
      mav<std::mutex,2> locks({nct,ncp});

      execStatic(idx.size(), nthreads, 0, [&](Scheduler &sched)
455
        {
456
        size_t b_theta=99999999999999, b_phi=9999999999999999;
457
458
459
460
461
        vector<T> wt(supp), wp(supp);
        vector<T> psiarr(2*kmax+1);
        while (auto rng=sched.getNext()) for(auto ind=rng.lo; ind<rng.hi; ++ind)
          {
          size_t i=idx[ind];
462
          T f0=0.5*supp+ptg(i,0)*xdtheta;
463
464
465
          size_t i0 = size_t(f0+1.);
          for (size_t t=0; t<supp; ++t)
            wt[t] = kernel((t+i0-f0)*delta - 1);
466
          T f1=0.5*supp+ptg(i,1)*xdphi;
467
468
469
470
471
472
473
474
475
          size_t i1 = size_t(f1+1.);
          for (size_t t=0; t<supp; ++t)
            wp[t] = kernel((t+i1-f1)*delta - 1);
          psiarr[0]=1.;
          double psi=ptg(i,2);
          double cpsi=cos(psi), spsi=sin(psi);
          double cnpsi=cpsi, snpsi=spsi;
          for (size_t l=1; l<=kmax; ++l)
            {
476
477
            psiarr[2*l-1]=T(cnpsi);
            psiarr[2*l]=T(snpsi);
478
479
480
481
            const double tmp = snpsi*cpsi + cnpsi*spsi;
            cnpsi=cnpsi*cpsi - snpsi*spsi;
            snpsi=tmp;
            }
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
          size_t b_theta_new = i0/cellsize,
                 b_phi_new = i1/cellsize;
          if ((b_theta_new!=b_theta) || (b_phi_new!=b_phi))
            {
            if (b_theta<locks.shape(0))  // unlock
              {
              locks.v(b_theta,b_phi).unlock();
              locks.v(b_theta,b_phi+1).unlock();
              locks.v(b_theta+1,b_phi).unlock();
              locks.v(b_theta+1,b_phi+1).unlock();
              }
            b_theta = b_theta_new;
            b_phi = b_phi_new;
            locks.v(b_theta,b_phi).lock();
            locks.v(b_theta,b_phi+1).lock();
            locks.v(b_theta+1,b_phi).lock();
            locks.v(b_theta+1,b_phi+1).lock();
            }
500
501
          if (ncomp==1)
            {
502
            T val = data(i,0);
503
504
505
506
507
508
509
            for (size_t j=0; j<supp; ++j)
              for (size_t k=0; k<supp; ++k)
                for (size_t l=0; l<2*kmax+1; ++l)
                  cube.v(i0+j,i1+k,l,0) += val*wt[j]*wp[k]*psiarr[l];
            }
          else // ncomp==3
            {
510
            T v0=data(i,0), v1=data(i,1), v2=data(i,2);
511
512
513
            for (size_t j=0; j<supp; ++j)
              for (size_t k=0; k<supp; ++k)
                {
514
                T t0 = wt[j]*wp[k];
515
516
                for (size_t l=0; l<2*kmax+1; ++l)
                  {
517
                  T tmp = t0*psiarr[l];
518
519
520
521
522
523
                  cube.v(i0+j,i1+k,l,0) += v0*tmp;
                  cube.v(i0+j,i1+k,l,1) += v1*tmp;
                  cube.v(i0+j,i1+k,l,2) += v2*tmp;
                  }
                }
            }
524
          }
525
526
527
528
529
530
531
        if (b_theta<locks.shape(0))  // unlock
          {
          locks.v(b_theta,b_phi).unlock();
          locks.v(b_theta,b_phi+1).unlock();
          locks.v(b_theta+1,b_phi).unlock();
          locks.v(b_theta+1,b_phi+1).unlock();
          }
532
533
        });
      }
534
    void getSlm (const vector<Alm<complex<T>>> &blm, vector<Alm<complex<T>>> &slm)
535
536
      {
      MR_assert(adjoint, "can only be called in adjoint mode");
537
538
      MR_assert((blm.size()==ncomp) || (ncomp==1), "incorrect number of beam a_lm sets");
      MR_assert((slm.size()==ncomp) || (ncomp==1), "incorrect number of sky a_lm sets");
539
540
541
542
543
      Alm<complex<T>> a1(lmax, lmax), a2(lmax,lmax);
      auto ginfo = sharp_make_cc_geom_info(ntheta0,nphi0,0.,cube.stride(1),cube.stride(0));
      auto ainfo = sharp_make_triangular_alm_info(lmax,lmax,1);

      // move stuff from border regions onto the main grid
544
      for (size_t i=0; i<cube.shape(0); ++i)
545
546
        for (size_t j=0; j<supp; ++j)
          for (size_t k=0; k<cube.shape(2); ++k)
547
548
549
550
551
            for (size_t l=0; l<cube.shape(3); ++l)
              {
              cube.v(i,j+nphi,k,l) += cube(i,j,k,l);
              cube.v(i,j+supp,k,l) += cube(i,j+nphi+supp,k,l);
              }
552
553
554
555
      for (size_t i=0; i<supp; ++i)
        for (size_t j=0, j2=nphi/2; j<nphi; ++j,++j2)
          for (size_t k=0; k<cube.shape(2); ++k)
            {
556
            T fct = (((k+1)/2)&1) ? -1 : 1;
557
            if (j2>=nphi) j2-=nphi;
558
559
560
561
562
            for (size_t l=0; l<cube.shape(3); ++l)
              {
              cube.v(supp+1+i,j+supp,k,l) += fct*cube(supp-1-i,j2+supp,k,l);
              cube.v(supp+ntheta-2-i, j+supp,k,l) += fct*cube(supp+ntheta+i,j2+supp,k,l);
              }
563
            }
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
564
565
566
567

      // special treatment for poles
      for (size_t j=0,j2=nphi/2; j<nphi/2; ++j,++j2)
        for (size_t k=0; k<cube.shape(2); ++k)
568
569
          for (size_t l=0; l<cube.shape(3); ++l)
            {
570
            T fct = (((k+1)/2)&1) ? -1 : 1;
571
            if (j2>=nphi) j2-=nphi;
572
            T tval = (cube(supp,j+supp,k,l) + fct*cube(supp,j2+supp,k,l));
573
574
575
576
577
578
            cube.v(supp,j+supp,k,l) = tval;
            cube.v(supp,j2+supp,k,l) = fct*tval;
            tval = (cube(supp+ntheta-1,j+supp,k,l) + fct*cube(supp+ntheta-1,j2+supp,k,l));
            cube.v(supp+ntheta-1,j+supp,k,l) = tval;
            cube.v(supp+ntheta-1,j2+supp,k,l) = fct*tval;
            }
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
579

580
      vector<T>lnorm(lmax+1);
581
      for (size_t i=0; i<=lmax; ++i)
Martin Reinecke's avatar
Martin Reinecke committed
582
        lnorm[i]=std::sqrt(4*pi/(2*i+1.));
583

584
585
586
      for (size_t j=0; j<blm.size(); ++j)
        slm[j].SetToZero();

587
      for (size_t icomp=0; icomp<ncomp; ++icomp)
588
        {
589
        bool separate = ncomp>1;
590
        {
591
        auto m1 = cube.template subarray<2>({supp,supp,0,icomp},{ntheta,nphi,0,0});
592
593
        decorrect(m1,0);
        sharp_alm2map_adjoint(a1.Alms().vdata(), m1.data(), *ginfo, *ainfo, 0, nthreads);
594
595
        for (size_t m=0; m<=lmax; ++m)
          for (size_t l=m; l<=lmax; ++l)
596
597
598
599
600
            if (separate)
              slm[icomp](l,m) += conj(a1(l,m))*blm[icomp](l,0).real()*T(lnorm[l]);
            else
              for (size_t j=0; j<blm.size(); ++j)
                slm[j](l,m) += conj(a1(l,m))*blm[j](l,0).real()*T(lnorm[l]);
601
602
603
        }
        for (size_t k=1; k<=kmax; ++k)
          {
604
605
          auto m1 = cube.template subarray<2>({supp,supp,2*k-1,icomp},{ntheta,nphi,0,0});
          auto m2 = cube.template subarray<2>({supp,supp,2*k  ,icomp},{ntheta,nphi,0,0});
606
607
608
609
610
611
612
613
          decorrect(m1,k);
          decorrect(m2,k);

          sharp_alm2map_spin_adjoint(k, a1.Alms().vdata(), a2.Alms().vdata(), m1.data(),
            m2.data(), *ginfo, *ainfo, 0, nthreads);
          for (size_t m=0; m<=lmax; ++m)
            for (size_t l=m; l<=lmax; ++l)
              if (l>=k)
614
615
                {
                if (separate)
616
                  {
617
                  auto tmp = conj(blm[icomp](l,k))*T(-2*lnorm[l]);
618
619
                  slm[icomp](l,m) += conj(a1(l,m))*tmp.real();
                  slm[icomp](l,m) -= conj(a2(l,m))*tmp.imag();
620
                  }
621
622
623
                else
                  for (size_t j=0; j<blm.size(); ++j)
                    {
624
                    auto tmp = conj(blm[j](l,k))*T(-2*lnorm[l]);
625
626
627
628
                    slm[j](l,m) += conj(a1(l,m))*tmp.real();
                    slm[j](l,m) -= conj(a2(l,m))*tmp.imag();
                    }
                }
629
          }
630
631
        }
      }
632
633
  };

634
635
636
/// Returns sqrt(12)*exp(-support*ofactor). This is presumably a rough
/// estimate of the accuracy epsilon reachable with a kernel of the given
/// support (grid cells) and oversampling factor, i.e. roughly the inverse
/// of ES_Kernel::get_supp — confirm.
/// The function is `inline` because it is defined in a header. A plain
/// non-template function definition here would violate the one-definition
/// rule (duplicate-symbol link errors) once two translation units include
/// this file.
inline double epsilon_guess(size_t support, double ofactor)
  { return std::sqrt(12.)*std::exp(-(support*ofactor)); }

Martin Reinecke's avatar
Martin Reinecke committed
637
}
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
638

Martin Reinecke's avatar
Martin Reinecke committed
639
using detail_interpol_ng::Interpolator;
640
using detail_interpol_ng::epsilon_guess;
641

Martin Reinecke's avatar
Martin Reinecke committed
642
}
643

Martin Reinecke's avatar
Martin Reinecke committed
644
#endif