/*
This file is part of pocketfft.

Copyright (C) 2010-2020 Max-Planck-Society
Copyright (C) 2019 Peter Bell

For the odd-sized DCT-IV transforms:
  Copyright (C) 2003, 2007-14 Matteo Frigo
  Copyright (C) 2003, 2007-14 Massachusetts Institute of Technology

Authors: Martin Reinecke, Peter Bell

All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this
  list of conditions and the following disclaimer in the documentation and/or
  other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its contributors may
  be used to endorse or promote products derived from this software without
  specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

#ifndef MRUTIL_FFT_H
#define MRUTIL_FFT_H

#include "mr_util/math/fft1d.h"

#ifndef POCKETFFT_CACHE_SIZE
#define POCKETFFT_CACHE_SIZE 16
#endif

#include <cmath>
#include <cstdlib>
#include <stdexcept>
#include <memory>
#include <numeric>
#include <vector>
#include <complex>
#include <algorithm>
#if POCKETFFT_CACHE_SIZE!=0
#include <array>
#endif
#include "mr_util/infra/threading.h"
#include "mr_util/infra/simd.h"
#include "mr_util/infra/mav.h"
#ifndef MRUTIL_NO_THREADING
#include <mutex>
#endif

namespace mr {

namespace detail_fft {

// Shorthands for the shape/stride container types used by fmav_info.
using shape_t=fmav_info::shape_t;
using stride_t=fmav_info::stride_t;

// Symbolic names for the boolean transform-direction flag.
constexpr bool FORWARD  = true,
               BACKWARD = false;

struct util // hack to avoid duplicate symbols
  {
Martin Reinecke's avatar
Martin Reinecke committed
77
  static void sanity_check_axes(size_t ndim, const shape_t &axes)
Martin Reinecke's avatar
Martin Reinecke committed
78
79
    {
    shape_t tmp(ndim,0);
Martin Reinecke's avatar
Martin Reinecke committed
80
    if (axes.empty()) throw std::invalid_argument("no axes specified");
Martin Reinecke's avatar
Martin Reinecke committed
81
82
    for (auto ax : axes)
      {
Martin Reinecke's avatar
Martin Reinecke committed
83
84
      if (ax>=ndim) throw std::invalid_argument("bad axis number");
      if (++tmp[ax]>1) throw std::invalid_argument("axis specified repeatedly");
Martin Reinecke's avatar
Martin Reinecke committed
85
86
87
      }
    }

Martin Reinecke's avatar
Martin Reinecke committed
88
89
  static MRUTIL_NOINLINE void sanity_check_onetype(const fmav_info &a1,
    const fmav_info &a2, bool inplace, const shape_t &axes)
Martin Reinecke's avatar
Martin Reinecke committed
90
    {
Martin Reinecke's avatar
Martin Reinecke committed
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
    sanity_check_axes(a1.ndim(), axes);
    MR_assert(a1.conformable(a2), "array sizes are not conformable");
    if (inplace) MR_assert(a1.stride()==a2.stride(), "stride mismatch");
    }
  static MRUTIL_NOINLINE void sanity_check_cr(const fmav_info &ac,
    const fmav_info &ar, const shape_t &axes)
    {
    sanity_check_axes(ac.ndim(), axes);
    MR_assert(ac.ndim()==ar.ndim(), "dimension mismatch");
    for (size_t i=0; i<ac.ndim(); ++i)
      MR_assert(ac.shape(i)== (i==axes.back()) ? (ar.shape(i)/2+1) : ar.shape(i),
        "axis length mismatch");
    }
  static MRUTIL_NOINLINE void sanity_check_cr(const fmav_info &ac,
    const fmav_info &ar, const size_t axis)
    {
    if (axis>=ac.ndim()) throw std::invalid_argument("bad axis number");
    MR_assert(ac.ndim()==ar.ndim(), "dimension mismatch");
    for (size_t i=0; i<ac.ndim(); ++i)
      MR_assert(ac.shape(i)== (i==axis) ? (ar.shape(i)/2+1) : ar.shape(i),
        "axis length mismatch");
Martin Reinecke's avatar
Martin Reinecke committed
112
113
114
    }

#ifdef MRUTIL_NO_THREADING
Martin Reinecke's avatar
Martin Reinecke committed
115
  static size_t thread_count (size_t /*nthreads*/, const fmav_info &/*info*/,
Martin Reinecke's avatar
Martin Reinecke committed
116
117
118
    size_t /*axis*/, size_t /*vlen*/)
    { return 1; }
#else
Martin Reinecke's avatar
Martin Reinecke committed
119
  static size_t thread_count (size_t nthreads, const fmav_info &info,
Martin Reinecke's avatar
Martin Reinecke committed
120
121
122
    size_t axis, size_t vlen)
    {
    if (nthreads==1) return 1;
Martin Reinecke's avatar
Martin Reinecke committed
123
124
125
    size_t size = info.size();
    size_t parallel = size / (info.shape(axis) * vlen);
    if (info.shape(axis) < 1000)
Martin Reinecke's avatar
Martin Reinecke committed
126
      parallel /= 4;
Martin Reinecke's avatar
Martin Reinecke committed
127
    size_t max_threads = (nthreads==0) ? mr::max_threads() : nthreads;
128
    return std::max(size_t(1), std::min(parallel, max_threads));
Martin Reinecke's avatar
Martin Reinecke committed
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
    }
#endif
  };


//
// sine/cosine transforms
//

template<typename T0> class T_dct1
  {
  private:
    pocketfft_r<T0> fftplan;

  public:
Martin Reinecke's avatar
Martin Reinecke committed
144
    MRUTIL_NOINLINE T_dct1(size_t length)
Martin Reinecke's avatar
Martin Reinecke committed
145
146
      : fftplan(2*(length-1)) {}

Martin Reinecke's avatar
Martin Reinecke committed
147
    template<typename T> MRUTIL_NOINLINE void exec(T c[], T0 fct, bool ortho,
Martin Reinecke's avatar
Martin Reinecke committed
148
149
150
151
152
153
      int /*type*/, bool /*cosine*/) const
      {
      constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
      size_t N=fftplan.length(), n=N/2+1;
      if (ortho)
        { c[0]*=sqrt2; c[n-1]*=sqrt2; }
154
      aligned_array<T> tmp(N);
Martin Reinecke's avatar
Martin Reinecke committed
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
      tmp[0] = c[0];
      for (size_t i=1; i<n; ++i)
        tmp[i] = tmp[N-i] = c[i];
      fftplan.exec(tmp.data(), fct, true);
      c[0] = tmp[0];
      for (size_t i=1; i<n; ++i)
        c[i] = tmp[2*i-1];
      if (ortho)
        { c[0]*=sqrt2*T0(0.5); c[n-1]*=sqrt2*T0(0.5); }
      }

    size_t length() const { return fftplan.length()/2+1; }
  };

template<typename T0> class T_dst1
  {
  private:
    pocketfft_r<T0> fftplan;

  public:
Martin Reinecke's avatar
Martin Reinecke committed
175
    MRUTIL_NOINLINE T_dst1(size_t length)
Martin Reinecke's avatar
Martin Reinecke committed
176
177
      : fftplan(2*(length+1)) {}

Martin Reinecke's avatar
Martin Reinecke committed
178
    template<typename T> MRUTIL_NOINLINE void exec(T c[], T0 fct,
Martin Reinecke's avatar
Martin Reinecke committed
179
180
181
      bool /*ortho*/, int /*type*/, bool /*cosine*/) const
      {
      size_t N=fftplan.length(), n=N/2-1;
182
      aligned_array<T> tmp(N);
Martin Reinecke's avatar
Martin Reinecke committed
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
      tmp[0] = tmp[n+1] = c[0]*0;
      for (size_t i=0; i<n; ++i)
        { tmp[i+1]=c[i]; tmp[N-1-i]=-c[i]; }
      fftplan.exec(tmp.data(), fct, true);
      for (size_t i=0; i<n; ++i)
        c[i] = -tmp[2*i+2];
      }

    size_t length() const { return fftplan.length()/2-1; }
  };

template<typename T0> class T_dcst23
  {
  private:
    pocketfft_r<T0> fftplan;
Martin Reinecke's avatar
Martin Reinecke committed
198
    std::vector<T0> twiddle;
Martin Reinecke's avatar
Martin Reinecke committed
199
200

  public:
Martin Reinecke's avatar
Martin Reinecke committed
201
    MRUTIL_NOINLINE T_dcst23(size_t length)
Martin Reinecke's avatar
Martin Reinecke committed
202
203
204
205
206
207
208
      : fftplan(length), twiddle(length)
      {
      UnityRoots<T0,Cmplx<T0>> tw(4*length);
      for (size_t i=0; i<length; ++i)
        twiddle[i] = tw[i+1].r;
      }

Martin Reinecke's avatar
Martin Reinecke committed
209
    template<typename T> MRUTIL_NOINLINE void exec(T c[], T0 fct, bool ortho,
Martin Reinecke's avatar
Martin Reinecke committed
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
      int type, bool cosine) const
      {
      constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
      size_t N=length();
      size_t NS2 = (N+1)/2;
      if (type==2)
        {
        if (!cosine)
          for (size_t k=1; k<N; k+=2)
            c[k] = -c[k];
        c[0] *= 2;
        if ((N&1)==0) c[N-1]*=2;
        for (size_t k=1; k<N-1; k+=2)
          MPINPLACE(c[k+1], c[k]);
        fftplan.exec(c, fct, false);
        for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
          {
          T t1 = twiddle[k-1]*c[kc]+twiddle[kc-1]*c[k];
          T t2 = twiddle[k-1]*c[k]-twiddle[kc-1]*c[kc];
          c[k] = T0(0.5)*(t1+t2); c[kc]=T0(0.5)*(t1-t2);
          }
        if ((N&1)==0)
          c[NS2] *= twiddle[NS2-1];
        if (!cosine)
          for (size_t k=0, kc=N-1; k<kc; ++k, --kc)
Martin Reinecke's avatar
Martin Reinecke committed
235
            std::swap(c[k], c[kc]);
Martin Reinecke's avatar
Martin Reinecke committed
236
237
238
239
240
241
242
        if (ortho) c[0]*=sqrt2*T0(0.5);
        }
      else
        {
        if (ortho) c[0]*=sqrt2;
        if (!cosine)
          for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
Martin Reinecke's avatar
Martin Reinecke committed
243
            std::swap(c[k], c[kc]);
Martin Reinecke's avatar
Martin Reinecke committed
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
        for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
          {
          T t1=c[k]+c[kc], t2=c[k]-c[kc];
          c[k] = twiddle[k-1]*t2+twiddle[kc-1]*t1;
          c[kc]= twiddle[k-1]*t1-twiddle[kc-1]*t2;
          }
        if ((N&1)==0)
          c[NS2] *= 2*twiddle[NS2-1];
        fftplan.exec(c, fct, true);
        for (size_t k=1; k<N-1; k+=2)
          MPINPLACE(c[k], c[k+1]);
        if (!cosine)
          for (size_t k=1; k<N; k+=2)
            c[k] = -c[k];
        }
      }

    size_t length() const { return fftplan.length(); }
  };

template<typename T0> class T_dcst4
  {
  private:
    size_t N;
Martin Reinecke's avatar
Martin Reinecke committed
268
269
    std::unique_ptr<pocketfft_c<T0>> fft;
    std::unique_ptr<pocketfft_r<T0>> rfft;
270
    aligned_array<Cmplx<T0>> C2;
Martin Reinecke's avatar
Martin Reinecke committed
271
272

  public:
Martin Reinecke's avatar
Martin Reinecke committed
273
    MRUTIL_NOINLINE T_dcst4(size_t length)
Martin Reinecke's avatar
Martin Reinecke committed
274
275
276
277
278
279
280
281
282
283
284
285
286
      : N(length),
        fft((N&1) ? nullptr : new pocketfft_c<T0>(N/2)),
        rfft((N&1)? new pocketfft_r<T0>(N) : nullptr),
        C2((N&1) ? 0 : N/2)
      {
      if ((N&1)==0)
        {
        UnityRoots<T0,Cmplx<T0>> tw(16*N);
        for (size_t i=0; i<N/2; ++i)
          C2[i] = conj(tw[8*i+1]);
        }
      }

Martin Reinecke's avatar
Martin Reinecke committed
287
    template<typename T> MRUTIL_NOINLINE void exec(T c[], T0 fct,
Martin Reinecke's avatar
Martin Reinecke committed
288
289
290
291
292
      bool /*ortho*/, int /*type*/, bool cosine) const
      {
      size_t n2 = N/2;
      if (!cosine)
        for (size_t k=0, kc=N-1; k<n2; ++k, --kc)
Martin Reinecke's avatar
Martin Reinecke committed
293
          std::swap(c[k], c[kc]);
Martin Reinecke's avatar
Martin Reinecke committed
294
295
296
297
298
299
      if (N&1)
        {
        // The following code is derived from the FFTW3 function apply_re11()
        // and is released under the 3-clause BSD license with friendly
        // permission of Matteo Frigo and Steven G. Johnson.

300
        aligned_array<T> y(N);
Martin Reinecke's avatar
Martin Reinecke committed
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
        {
        size_t i=0, m=n2;
        for (; m<N; ++i, m+=4)
          y[i] = c[m];
        for (; m<2*N; ++i, m+=4)
          y[i] = -c[2*N-m-1];
        for (; m<3*N; ++i, m+=4)
          y[i] = -c[m-2*N];
        for (; m<4*N; ++i, m+=4)
          y[i] = c[4*N-m-1];
        for (; i<N; ++i, m+=4)
          y[i] = c[m-4*N];
        }
        rfft->exec(y.data(), fct, true);
        {
        auto SGN = [](size_t i)
           {
           constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
           return (i&2) ? -sqrt2 : sqrt2;
           };
        c[n2] = y[0]*SGN(n2+1);
        size_t i=0, i1=1, k=1;
        for (; k<n2; ++i, ++i1, k+=2)
          {
          c[i    ] = y[2*k-1]*SGN(i1)     + y[2*k  ]*SGN(i);
          c[N -i1] = y[2*k-1]*SGN(N -i)   - y[2*k  ]*SGN(N -i1);
          c[n2-i1] = y[2*k+1]*SGN(n2-i)   - y[2*k+2]*SGN(n2-i1);
          c[n2+i1] = y[2*k+1]*SGN(n2+i+2) + y[2*k+2]*SGN(n2+i1);
          }
        if (k == n2)
          {
          c[i   ] = y[2*k-1]*SGN(i+1) + y[2*k]*SGN(i);
          c[N-i1] = y[2*k-1]*SGN(i+2) + y[2*k]*SGN(i1);
          }
        }

        // FFTW-derived code ends here
        }
      else
        {
        // even length algorithm from
        // https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/
343
        aligned_array<Cmplx<T>> y(n2);
Martin Reinecke's avatar
Martin Reinecke committed
344
345
346
347
348
349
350
351
        for(size_t i=0; i<n2; ++i)
          {
          y[i].Set(c[2*i],c[N-1-2*i]);
          y[i] *= C2[i];
          }
        fft->exec(y.data(), fct, true);
        for(size_t i=0, ic=n2-1; i<n2; ++i, --ic)
          {
Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
352
353
          c[2*i  ] = T0( 2)*(y[i ].r*C2[i ].r-y[i ].i*C2[i ].i);
          c[2*i+1] = T0(-2)*(y[ic].i*C2[ic].r+y[ic].r*C2[ic].i);
Martin Reinecke's avatar
Martin Reinecke committed
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
          }
        }
      if (!cosine)
        for (size_t k=1; k<N; k+=2)
          c[k] = -c[k];
      }

    size_t length() const { return N; }
  };


//
// multi-D infrastructure
//

// Returns a plan of type T for the given length.
//
// T must be constructible from a size_t and provide length().
// If POCKETFFT_CACHE_SIZE is 0, a new plan is created on every call.
// Otherwise up to POCKETFFT_CACHE_SIZE plans per type T are kept in a
// function-local cache; when a new plan has to be stored, the entry with the
// smallest last-access stamp is overwritten. Unless MRUTIL_NO_THREADING is
// defined, cache accesses are protected by a mutex, while the (potentially
// expensive) plan construction happens outside the lock.
template<typename T> std::shared_ptr<T> get_plan(size_t length)
  {
#if POCKETFFT_CACHE_SIZE==0
  return std::make_shared<T>(length);
#else
  constexpr size_t nmax=POCKETFFT_CACHE_SIZE;
  static std::array<std::shared_ptr<T>, nmax> cache;
  static std::array<size_t, nmax> last_access{{0}};
  static size_t access_counter = 0;

  // Stamps slot i as the most recently used one.
  auto mark_used = [&](size_t i)
    {
    last_access[i] = ++access_counter;
    // Guard against overflow
    if (access_counter == 0)
      last_access.fill(0);
    };

  // Looks for a cached plan of the requested length; nullptr if none exists.
  auto find_in_cache = [&]() -> std::shared_ptr<T>
    {
    for (size_t i=0; i<nmax; ++i)
      if (cache[i] && (cache[i]->length()==length))
        {
        // no need to update if this is already the most recent entry
        if (last_access[i]!=access_counter)
          mark_used(i);
        return cache[i];
        }

    return nullptr;
    };

  // Stores `plan` in the least recently used slot.
  auto store_in_cache = [&](const std::shared_ptr<T> &plan)
    {
    size_t lru = 0;
    for (size_t i=1; i<nmax; ++i)
      if (last_access[i] < last_access[lru])
        lru = i;

    cache[lru] = plan;
    mark_used(lru);
    };

#ifdef MRUTIL_NO_THREADING
  auto p = find_in_cache();
  if (p) return p;
  auto plan = std::make_shared<T>(length);
  store_in_cache(plan);
  return plan;
#else
  static std::mutex mut;

  {
  std::lock_guard<std::mutex> lock(mut);
  auto p = find_in_cache();
  if (p) return p;
  }
  // Build the plan without holding the lock.
  auto plan = std::make_shared<T>(length);
  {
  std::lock_guard<std::mutex> lock(mut);
  // Another thread may have stored an equivalent plan in the meantime.
  auto p = find_in_cache();
  if (p) return p;
  store_in_cache(plan);
  }
  return plan;
#endif
#endif
  }

template<size_t N> class multi_iter
  {
  private:
Martin Reinecke's avatar
Martin Reinecke committed
439
440
441
442
443
    shape_t shp, pos;
    stride_t str_i, str_o;
    size_t cshp_i, cshp_o, rem;
    ptrdiff_t cstr_i, cstr_o, sstr_i, sstr_o, p_ii, p_i[N], p_oi, p_o[N];
    bool uni_i, uni_o;
Martin Reinecke's avatar
Martin Reinecke committed
444
445
446

    void advance_i()
      {
Martin Reinecke's avatar
Martin Reinecke committed
447
      for (size_t i=0; i<pos.size(); ++i)
Martin Reinecke's avatar
Martin Reinecke committed
448
        {
Martin Reinecke's avatar
Martin Reinecke committed
449
450
451
        p_ii += str_i[i];
        p_oi += str_o[i];
        if (++pos[i] < shp[i])
Martin Reinecke's avatar
Martin Reinecke committed
452
453
          return;
        pos[i] = 0;
Martin Reinecke's avatar
Martin Reinecke committed
454
455
        p_ii -= ptrdiff_t(shp[i])*str_i[i];
        p_oi -= ptrdiff_t(shp[i])*str_o[i];
Martin Reinecke's avatar
Martin Reinecke committed
456
457
458
459
        }
      }

  public:
Martin Reinecke's avatar
Martin Reinecke committed
460
461
462
    multi_iter(const fmav_info &iarr, const fmav_info &oarr, size_t idim,
      size_t nshares, size_t myshare)
      : rem(iarr.size()/iarr.shape(idim)), sstr_i(0), sstr_o(0), p_ii(0), p_oi(0)
Martin Reinecke's avatar
Martin Reinecke committed
463
      {
Martin Reinecke's avatar
Martin Reinecke committed
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
      MR_assert(oarr.ndim()==iarr.ndim(), "dimension mismatch");
      MR_assert(iarr.ndim()>=1, "not enough dimensions");
      // Sort the extraneous dimensions in order of ascending output stride;
      // this should improve overall cache re-use and avoid clashes between
      // threads as much as possible.
      shape_t idx(iarr.ndim());
      std::iota(idx.begin(), idx.end(), 0);
      sort(idx.begin(), idx.end(),
        [&oarr](size_t i1, size_t i2) {return oarr.stride(i1) < oarr.stride(i2);});
      for (auto i: idx)
        if (i!=idim)
          {
          pos.push_back(0);
          MR_assert(iarr.shape(i)==oarr.shape(i), "shape mismatch");
          shp.push_back(iarr.shape(i));
          str_i.push_back(iarr.stride(i));
          str_o.push_back(oarr.stride(i));
          }
      MR_assert(idim<iarr.ndim(), "bad active dimension");
      cstr_i = iarr.stride(idim);
      cstr_o = oarr.stride(idim);
      cshp_i = iarr.shape(idim);
      cshp_o = oarr.shape(idim);

// collapse unneeded dimensions
      bool done = false;
      while(!done)
        {
        done=true;
        for (size_t i=1; i<shp.size(); ++i)
          if ((str_i[i] == str_i[i-1]*ptrdiff_t(shp[i-1]))
           && (str_o[i] == str_o[i-1]*ptrdiff_t(shp[i-1])))
            {
            shp[i-1] *= shp[i];
            str_i.erase(str_i.begin()+ptrdiff_t(i));
            str_o.erase(str_o.begin()+ptrdiff_t(i));
            shp.erase(shp.begin()+ptrdiff_t(i));
            pos.pop_back();
            done=false;
            }
        }
      if (pos.size()>0)
        {
        sstr_i = str_i[0];
        sstr_o = str_o[0];
        }

Martin Reinecke's avatar
Martin Reinecke committed
511
      if (nshares==1) return;
Martin Reinecke's avatar
Martin Reinecke committed
512
513
      if (nshares==0) throw std::runtime_error("can't run with zero threads");
      if (myshare>=nshares) throw std::runtime_error("impossible share requested");
Martin Reinecke's avatar
Martin Reinecke committed
514
515
516
517
518
519
520
      size_t nbase = rem/nshares;
      size_t additional = rem%nshares;
      size_t lo = myshare*nbase + ((myshare<additional) ? myshare : additional);
      size_t hi = lo+nbase+(myshare<additional);
      size_t todo = hi-lo;

      size_t chunk = rem;
Martin Reinecke's avatar
Martin Reinecke committed
521
      for (size_t i2=0, i=pos.size()-1; i2<pos.size(); ++i2,--i)
Martin Reinecke's avatar
Martin Reinecke committed
522
        {
Martin Reinecke's avatar
Martin Reinecke committed
523
        chunk /= shp[i];
Martin Reinecke's avatar
Martin Reinecke committed
524
525
        size_t n_advance = lo/chunk;
        pos[i] += n_advance;
Martin Reinecke's avatar
Martin Reinecke committed
526
527
        p_ii += ptrdiff_t(n_advance)*str_i[i];
        p_oi += ptrdiff_t(n_advance)*str_o[i];
Martin Reinecke's avatar
Martin Reinecke committed
528
529
        lo -= n_advance*chunk;
        }
Martin Reinecke's avatar
Martin Reinecke committed
530
      MR_assert(lo==0, "must not happen");
Martin Reinecke's avatar
Martin Reinecke committed
531
532
533
534
      rem = todo;
      }
    void advance(size_t n)
      {
Martin Reinecke's avatar
Martin Reinecke committed
535
      if (rem<n) throw std::runtime_error("underrun");
Martin Reinecke's avatar
Martin Reinecke committed
536
537
538
539
540
541
      for (size_t i=0; i<n; ++i)
        {
        p_i[i] = p_ii;
        p_o[i] = p_oi;
        advance_i();
        }
Martin Reinecke's avatar
Martin Reinecke committed
542
543
544
545
546
547
      uni_i = uni_o = true;
      for (size_t i=1; i<n; ++i)
        {
        uni_i = uni_i && (p_i[i]-p_i[i-1] == sstr_i);
        uni_o = uni_o && (p_o[i]-p_o[i-1] == sstr_o);
        }
Martin Reinecke's avatar
Martin Reinecke committed
548
549
      rem -= n;
      }
Martin Reinecke's avatar
Martin Reinecke committed
550
551
552
553
554
555
556
557
558
559
560
561
562
563
    ptrdiff_t iofs(size_t i) const { return p_i[0] + ptrdiff_t(i)*cstr_i; }
    ptrdiff_t iofs(size_t j, size_t i) const { return p_i[j] + ptrdiff_t(i)*cstr_i; }
    ptrdiff_t iofs_uni(size_t j, size_t i) const { return p_i[0] + ptrdiff_t(j)*sstr_i + ptrdiff_t(i)*cstr_i; }
    ptrdiff_t oofs(size_t i) const { return p_o[0] + ptrdiff_t(i)*cstr_o; }
    ptrdiff_t oofs(size_t j, size_t i) const { return p_o[j] + ptrdiff_t(i)*cstr_o; }
    ptrdiff_t oofs_uni(size_t j, size_t i) const { return p_o[0] + ptrdiff_t(j)*sstr_o + ptrdiff_t(i)*cstr_o; }
    bool uniform_i() const { return uni_i; } 
    ptrdiff_t unistride_i() const { return sstr_i; } 
    bool uniform_o() const { return uni_o; } 
    ptrdiff_t unistride_o() const { return sstr_o; } 
    size_t length_in() const { return cshp_i; }
    size_t length_out() const { return cshp_o; }
    ptrdiff_t stride_in() const { return cstr_i; }
    ptrdiff_t stride_out() const { return cstr_o; }
Martin Reinecke's avatar
Martin Reinecke committed
564
565
566
567
568
569
570
    size_t remaining() const { return rem; }
  };

class rev_iter
  {
  private:
    shape_t pos;
Martin Reinecke's avatar
Martin Reinecke committed
571
    fmav_info arr;
Martin Reinecke's avatar
Martin Reinecke committed
572
573
    std::vector<char> rev_axis;
    std::vector<char> rev_jump;
Martin Reinecke's avatar
Martin Reinecke committed
574
575
576
577
578
579
    size_t last_axis, last_size;
    shape_t shp;
    ptrdiff_t p, rp;
    size_t rem;

  public:
Martin Reinecke's avatar
Martin Reinecke committed
580
    rev_iter(const fmav_info &arr_, const shape_t &axes)
Martin Reinecke's avatar
Martin Reinecke committed
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
      : pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0),
        rev_jump(arr_.ndim(), 1), p(0), rp(0)
      {
      for (auto ax: axes)
        rev_axis[ax]=1;
      last_axis = axes.back();
      last_size = arr.shape(last_axis)/2 + 1;
      shp = arr.shape();
      shp[last_axis] = last_size;
      rem=1;
      for (auto i: shp)
        rem *= i;
      }
    void advance()
      {
      --rem;
      for (int i_=int(pos.size())-1; i_>=0; --i_)
        {
        auto i = size_t(i_);
        p += arr.stride(i);
        if (!rev_axis[i])
          rp += arr.stride(i);
        else
          {
          rp -= arr.stride(i);
          if (rev_jump[i])
            {
            rp += ptrdiff_t(arr.shape(i))*arr.stride(i);
            rev_jump[i] = 0;
            }
          }
        if (++pos[i] < shp[i])
          return;
        pos[i] = 0;
        p -= ptrdiff_t(shp[i])*arr.stride(i);
        if (rev_axis[i])
          {
          rp -= ptrdiff_t(arr.shape(i)-shp[i])*arr.stride(i);
          rev_jump[i] = 1;
          }
        else
          rp -= ptrdiff_t(shp[i])*arr.stride(i);
        }
      }
    ptrdiff_t ofs() const { return p; }
    ptrdiff_t rev_ofs() const { return rp; }
    size_t remaining() const { return rem; }
  };

// Allocates scratch storage for transforming `info` along an axis of length
// `axsize`: room for native_simd<T0>::size() lines if the array contains at
// least that many lines, otherwise room for a single line.
template<typename T, typename T0> aligned_array<T> alloc_tmp
  (const fmav_info &info, size_t axsize)
  {
  // A zero-length axis needs no storage (and must not be divided by).
  auto othersize = (axsize==0) ? size_t(0) : info.size()/axsize;
  constexpr auto vlen = native_simd<T0>::size();
  auto tmpsize = axsize*((othersize>=vlen) ? vlen : 1);
  return aligned_array<T>(tmpsize);
  }

// Gathers `vlen` complex input lines into SIMD form for the case in which
// the lines lie directly next to each other in memory: lane j of dst[i] is
// read from offset j + i*stride_in relative to the first element of line 0.
template <typename T, size_t vlen> MRUTIL_NOINLINE void copy_input_j1(const multi_iter<vlen> &it,
  const fmav<Cmplx<T>> &src, Cmplx<native_simd<T>> *MRUTIL_RESTRICT dst)
  {
  auto ptr = &src[it.iofs_uni(0,0)];
  auto istr = it.stride_in();
  for (size_t i=0; i<it.length_in(); ++i)
    {
    Cmplx<native_simd<T>> stmp;
    for (size_t j=0; j<vlen; ++j)
      {
      // Compute the offset in signed arithmetic so that negative strides
      // do not wrap around to huge unsigned values.
      auto tmp = ptr[ptrdiff_t(j)+ptrdiff_t(i)*istr];
      stmp.r[j] = tmp.r;
      stmp.i[j] = tmp.i;
      }
    dst[i] = stmp;
    }
  }
// Gathers `vlen` complex input lines into SIMD form for the case in which
// the elements of each line are contiguous: lane j of dst[i] is read from
// offset j*unistride_i + i relative to the first element of line 0.
template <typename T, size_t vlen> MRUTIL_NOINLINE void copy_input_i1(const multi_iter<vlen> &it,
  const fmav<Cmplx<T>> &src, Cmplx<native_simd<T>> *MRUTIL_RESTRICT dst)
  {
  auto ptr = &src[it.iofs_uni(0,0)];
  auto jstr = it.unistride_i();
  for (size_t i=0; i<it.length_in(); ++i)
    {
    Cmplx<native_simd<T>> stmp;
    for (size_t j=0; j<vlen; ++j)
      {
      // Compute the offset in signed arithmetic so that negative strides
      // do not wrap around to huge unsigned values.
      auto tmp = ptr[ptrdiff_t(j)*jstr+ptrdiff_t(i)];
      stmp.r[j] = tmp.r;
      stmp.i[j] = tmp.i;
      }
    dst[i] = stmp;
    }
  }
// Gathers the `vlen` complex input lines currently fetched by `it` into SIMD
// form: lane j of dst[i] receives element i of line j.
// If the lines are equally spaced, dedicated loops are used for unit stride
// along the line and for unit stride between the lines.
template <typename T, size_t vlen> MRUTIL_NOINLINE void copy_input(const multi_iter<vlen> &it,
  const fmav<Cmplx<T>> &src, Cmplx<native_simd<T>> *MRUTIL_RESTRICT dst)
  {
  const size_t len = it.length_in();
  if (it.uniform_i())
    {
    auto ptr = &src[it.iofs_uni(0,0)];
    auto jstr = it.unistride_i();
    auto istr = it.stride_in();
    if (istr==1)
      // elements of a line are contiguous
      for (size_t i=0; i<len; ++i)
        {
        Cmplx<native_simd<T>> stmp;
        for (size_t j=0; j<vlen; ++j)
          {
          auto tmp = ptr[ptrdiff_t(j)*jstr+ptrdiff_t(i)];
          stmp.r[j] = tmp.r;
          stmp.i[j] = tmp.i;
          }
        dst[i] = stmp;
        }
    else if (jstr==1)
      // the lines lie directly next to each other
      for (size_t i=0; i<len; ++i)
        {
        Cmplx<native_simd<T>> stmp;
        for (size_t j=0; j<vlen; ++j)
          {
          auto tmp = ptr[ptrdiff_t(j)+ptrdiff_t(i)*istr];
          stmp.r[j] = tmp.r;
          stmp.i[j] = tmp.i;
          }
        dst[i] = stmp;
        }
    else
      // equally spaced lines, arbitrary strides
      for (size_t i=0; i<len; ++i)
        {
        Cmplx<native_simd<T>> stmp;
        for (size_t j=0; j<vlen; ++j)
          {
          auto tmp = src[it.iofs_uni(j,i)];
          stmp.r[j] = tmp.r;
          stmp.i[j] = tmp.i;
          }
        dst[i] = stmp;
        }
    }
  else
    // lines at irregular offsets
    for (size_t i=0; i<len; ++i)
      {
      Cmplx<native_simd<T>> stmp;
      for (size_t j=0; j<vlen; ++j)
        {
        auto tmp = src[it.iofs(j,i)];
        stmp.r[j] = tmp.r;
        stmp.i[j] = tmp.i;
        }
      dst[i] = stmp;
      }
  }

// Gathers the `vlen` real input lines currently fetched by `it` into SIMD
// form: lane j of dst[i] receives element i of line j.
template <typename T, size_t vlen> MRUTIL_NOINLINE void copy_input(const multi_iter<vlen> &it,
  const fmav<T> &src, native_simd<T> *MRUTIL_RESTRICT dst)
  {
  if (it.uniform_i())
    // equally spaced lines: offsets can be computed from line 0
    for (size_t i=0; i<it.length_in(); ++i)
      for (size_t j=0; j<vlen; ++j)
        dst[i][j] = src[it.iofs_uni(j,i)];
  else
    for (size_t i=0; i<it.length_in(); ++i)
      for (size_t j=0; j<vlen; ++j)
        dst[i][j] = src[it.iofs(j,i)];
  }

// Copies a single input line of `src` into the contiguous scalar buffer
// `dst`, reading it.length_in() elements at the offsets given by it.iofs().
template <typename T, size_t vlen> MRUTIL_NOINLINE void copy_input(const multi_iter<vlen> &it,
  const fmav<T> &src, T *MRUTIL_RESTRICT dst)
  {
  // dst already points at the first element of this input line
  // (in-place operation): there is nothing to copy.
  if (&src[it.iofs(0)] == dst) return;
  const auto npos = it.length_in();
  size_t pos = 0;
  while (pos<npos)
    {
    dst[pos] = src[it.iofs(pos)];
    ++pos;
    }
  }

Martin Reinecke's avatar
Martin Reinecke committed
755
// Scatters the complex SIMD buffer `src` into `dst`.
// For every position pos in [0, it.length_out()) and every lane, the pair
// (src[pos].r[lane], src[pos].i[lane]) is stored via Set() at the output
// offset reported for (lane, pos). The offsets come from oofs_uni() when
// it.uniform_o() holds and from oofs() otherwise.
template<typename T, size_t vlen> MRUTIL_NOINLINE void copy_output(const multi_iter<vlen> &it,
  const Cmplx<native_simd<T>> *MRUTIL_RESTRICT src, fmav<Cmplx<T>> &dst)
  {
  auto target = dst.vdata();
  const auto npos = it.length_out();
  if (it.uniform_o())
    {
    for (size_t pos=0; pos<npos; ++pos)
      {
      const auto &vec = src[pos];
      for (size_t lane=0; lane<vlen; ++lane)
        target[it.oofs_uni(lane,pos)].Set(vec.r[lane], vec.i[lane]);
      }
    return;
    }
  for (size_t pos=0; pos<npos; ++pos)
    {
    const auto &vec = src[pos];
    for (size_t lane=0; lane<vlen; ++lane)
      target[it.oofs(lane,pos)].Set(vec.r[lane], vec.i[lane]);
    }
  }

Martin Reinecke's avatar
Martin Reinecke committed
769
// Scatters the real SIMD buffer `src` into `dst`.
// For every position pos in [0, it.length_out()) and every lane, element
// src[pos][lane] is written to the output offset reported for (lane, pos).
// The offsets come from oofs_uni() when it.uniform_o() holds and from
// oofs() otherwise.
template<typename T, size_t vlen> MRUTIL_NOINLINE void copy_output(const multi_iter<vlen> &it,
  const native_simd<T> *MRUTIL_RESTRICT src, fmav<T> &dst)
  {
  auto target = dst.vdata();
  const auto npos = it.length_out();
  if (it.uniform_o())
    {
    for (size_t pos=0; pos<npos; ++pos)
      {
      const auto &vec = src[pos];
      for (size_t lane=0; lane<vlen; ++lane)
        target[it.oofs_uni(lane,pos)] = vec[lane];
      }
    return;
    }
  for (size_t pos=0; pos<npos; ++pos)
    {
    const auto &vec = src[pos];
    for (size_t lane=0; lane<vlen; ++lane)
      target[it.oofs(lane,pos)] = vec[lane];
    }
  }

Martin Reinecke's avatar
Martin Reinecke committed
783
// Copies the contiguous scalar buffer `src` into one output line of `dst`,
// writing it.length_out() elements at the offsets given by it.oofs().
template<typename T, size_t vlen> MRUTIL_NOINLINE void copy_output(const multi_iter<vlen> &it,
  const T *MRUTIL_RESTRICT src, fmav<T> &dst)
  {
  auto target = dst.vdata();
  // src already points at the first element of this output line
  // (in-place operation): the data is where it belongs.
  if (&dst[it.oofs(0)] == src) return;
  const auto npos = it.length_out();
  size_t pos = 0;
  while (pos<npos)
    {
    target[it.oofs(pos)] = src[pos];
    ++pos;
    }
  }

Martin Reinecke's avatar
Martin Reinecke committed
792
template <typename T> struct add_vec { using type = native_simd<T>; };
Martin Reinecke's avatar
Martin Reinecke committed
793
template <typename T> struct add_vec<Cmplx<T>>
Martin Reinecke's avatar
Martin Reinecke committed
794
  { using type = Cmplx<native_simd<T>>; };
Martin Reinecke's avatar
Martin Reinecke committed
795
796
797
template <typename T> using add_vec_t = typename add_vec<T>::type;

template<typename Tplan, typename T, typename T0, typename Exec>
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
798
MRUTIL_NOINLINE void general_nd(const fmav<T> &in, fmav<T> &out,
Martin Reinecke's avatar
Martin Reinecke committed
799
800
801
  const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec,
  const bool allow_inplace=true)
  {
Martin Reinecke's avatar
Martin Reinecke committed
802
  std::shared_ptr<Tplan> plan;
Martin Reinecke's avatar
Martin Reinecke committed
803
804
805
806
807
808
809
810

  for (size_t iax=0; iax<axes.size(); ++iax)
    {
    size_t len=in.shape(axes[iax]);
    if ((!plan) || (len!=plan->length()))
      plan = get_plan<Tplan>(len);

    execParallel(
Martin Reinecke's avatar
Martin Reinecke committed
811
      util::thread_count(nthreads, in, axes[iax], native_simd<T0>::size()),
Martin Reinecke's avatar
rework    
Martin Reinecke committed
812
      [&](Scheduler &sched) {
Martin Reinecke's avatar
Martin Reinecke committed
813
814
        constexpr auto vlen = native_simd<T0>::size();
        auto storage = alloc_tmp<T,T0>(in, len);
Martin Reinecke's avatar
Martin Reinecke committed
815
        const auto &tin(iax==0? in : out);
Martin Reinecke's avatar
rework    
Martin Reinecke committed
816
        multi_iter<vlen> it(tin, out, axes[iax], sched.num_threads(), sched.thread_num());
Martin Reinecke's avatar
Martin Reinecke committed
817
#ifndef MRUTIL_NO_SIMD
Martin Reinecke's avatar
Martin Reinecke committed
818
819
820
821
822
823
824
825
826
827
828
        if (vlen>1)
          while (it.remaining()>=vlen)
            {
            it.advance(vlen);
            auto tdatav = reinterpret_cast<add_vec_t<T> *>(storage.data());
            exec(it, tin, out, tdatav, *plan, fct);
            }
#endif
        while (it.remaining()>0)
          {
          it.advance(1);
Martin Reinecke's avatar
Martin Reinecke committed
829
          auto buf = allow_inplace && it.stride_out() == 1 ?
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
830
            &out.vraw(it.oofs(0)) : reinterpret_cast<T *>(storage.data());
Martin Reinecke's avatar
Martin Reinecke committed
831
832
833
834
835
836
837
838
839
840
841
          exec(it, tin, out, buf, *plan, fct);
          }
      });  // end of parallel region
    fct = T0(1); // factor has been applied, use 1 for remaining axes
    }
  }

struct ExecC2C
  {
  bool forward;

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
842
843
844
  template <typename T0, typename T, size_t vlen> void operator() (
    const multi_iter<vlen> &it, const fmav<Cmplx<T0>> &in,
    fmav<Cmplx<T0>> &out, T *buf, const pocketfft_c<T0> &plan, T0 fct) const
Martin Reinecke's avatar
Martin Reinecke committed
845
846
847
848
849
850
851
852
    {
    copy_input(it, in, buf);
    plan.exec(buf, fct, forward);
    copy_output(it, buf, out);
    }
  };

template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
853
  const native_simd<T> *MRUTIL_RESTRICT src, fmav<T> &dst)
Martin Reinecke's avatar
Martin Reinecke committed
854
  {
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
855
  auto ptr = dst.vdata();
Martin Reinecke's avatar
Martin Reinecke committed
856
  for (size_t j=0; j<vlen; ++j)
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
857
    ptr[it.oofs(j,0)] = src[0][j];
Martin Reinecke's avatar
Martin Reinecke committed
858
859
860
861
  size_t i=1, i1=1, i2=it.length_out()-1;
  for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2)
    for (size_t j=0; j<vlen; ++j)
      {
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
862
863
      ptr[it.oofs(j,i1)] = src[i][j]+src[i+1][j];
      ptr[it.oofs(j,i2)] = src[i][j]-src[i+1][j];
Martin Reinecke's avatar
Martin Reinecke committed
864
865
866
      }
  if (i<it.length_out())
    for (size_t j=0; j<vlen; ++j)
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
867
      ptr[it.oofs(j,i1)] = src[i][j];
Martin Reinecke's avatar
Martin Reinecke committed
868
869
870
  }

template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
871
  const T *MRUTIL_RESTRICT src, fmav<T> &dst)
Martin Reinecke's avatar
Martin Reinecke committed
872
  {
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
873
874
  auto ptr = dst.vdata();
  ptr[it.oofs(0)] = src[0];
Martin Reinecke's avatar
Martin Reinecke committed
875
876
877
  size_t i=1, i1=1, i2=it.length_out()-1;
  for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2)
    {
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
878
879
    ptr[it.oofs(i1)] = src[i]+src[i+1];
    ptr[it.oofs(i2)] = src[i]-src[i+1];
Martin Reinecke's avatar
Martin Reinecke committed
880
881
    }
  if (i<it.length_out())
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
882
    ptr[it.oofs(i1)] = src[i];
Martin Reinecke's avatar
Martin Reinecke committed
883
884
885
886
887
  }

struct ExecHartley
  {
  template <typename T0, typename T, size_t vlen> void operator () (
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
888
    const multi_iter<vlen> &it, const fmav<T0> &in, fmav<T0> &out,
Martin Reinecke's avatar
Martin Reinecke committed
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
    T * buf, const pocketfft_r<T0> &plan, T0 fct) const
    {
    copy_input(it, in, buf);
    plan.exec(buf, fct, true);
    copy_hartley(it, buf, out);
    }
  };

struct ExecDcst
  {
  bool ortho;
  int type;
  bool cosine;

  template <typename T0, typename T, typename Tplan, size_t vlen>
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
904
905
  void operator () (const multi_iter<vlen> &it, const fmav<T0> &in,
    fmav <T0> &out, T * buf, const Tplan &plan, T0 fct) const
Martin Reinecke's avatar
Martin Reinecke committed
906
907
908
909
910
911
912
    {
    copy_input(it, in, buf);
    plan.exec(buf, fct, ortho, type, cosine);
    copy_output(it, buf, out);
    }
  };

Martin Reinecke's avatar
Martin Reinecke committed
913
template<typename T> MRUTIL_NOINLINE void general_r2c(
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
914
  const fmav<T> &in, fmav<Cmplx<T>> &out, size_t axis, bool forward, T fct,
Martin Reinecke's avatar
Martin Reinecke committed
915
916
917
918
919
  size_t nthreads)
  {
  auto plan = get_plan<pocketfft_r<T>>(in.shape(axis));
  size_t len=in.shape(axis);
  execParallel(
Martin Reinecke's avatar
Martin Reinecke committed
920
    util::thread_count(nthreads, in, axis, native_simd<T>::size()),
Martin Reinecke's avatar
rework    
Martin Reinecke committed
921
    [&](Scheduler &sched) {
Martin Reinecke's avatar
Martin Reinecke committed
922
923
    constexpr auto vlen = native_simd<T>::size();
    auto storage = alloc_tmp<T,T>(in, len);
Martin Reinecke's avatar
rework    
Martin Reinecke committed
924
    multi_iter<vlen> it(in, out, axis, sched.num_threads(), sched.thread_num());
Martin Reinecke's avatar
Martin Reinecke committed
925
#ifndef MRUTIL_NO_SIMD
Martin Reinecke's avatar
Martin Reinecke committed
926
927
928
929
    if (vlen>1)
      while (it.remaining()>=vlen)
        {
        it.advance(vlen);
Martin Reinecke's avatar
Martin Reinecke committed
930
        auto tdatav = reinterpret_cast<native_simd<T> *>(storage.data());
Martin Reinecke's avatar
Martin Reinecke committed
931
932
        copy_input(it, in, tdatav);
        plan->exec(tdatav, fct, true);
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
933
        auto vout = out.vdata();
Martin Reinecke's avatar
Martin Reinecke committed
934
        for (size_t j=0; j<vlen; ++j)
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
935
          vout[it.oofs(j,0)].Set(tdatav[0][j]);
Martin Reinecke's avatar
Martin Reinecke committed
936
937
938
939
        size_t i=1, ii=1;
        if (forward)
          for (; i<len-1; i+=2, ++ii)
            for (size_t j=0; j<vlen; ++j)
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
940
              vout[it.oofs(j,ii)].Set(tdatav[i][j], tdatav[i+1][j]);
Martin Reinecke's avatar
Martin Reinecke committed
941
942
943
        else
          for (; i<len-1; i+=2, ++ii)
            for (size_t j=0; j<vlen; ++j)
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
944
              vout[it.oofs(j,ii)].Set(tdatav[i][j], -tdatav[i+1][j]);
Martin Reinecke's avatar
Martin Reinecke committed
945
946
        if (i<len)
          for (size_t j=0; j<vlen; ++j)
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
947
            vout[it.oofs(j,ii)].Set(tdatav[i][j]);
Martin Reinecke's avatar
Martin Reinecke committed
948
949
950
951
952
953
954
955
        }
#endif
    while (it.remaining()>0)
      {
      it.advance(1);
      auto tdata = reinterpret_cast<T *>(storage.data());
      copy_input(it, in, tdata);
      plan->exec(tdata, fct, true);
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
956
957
      auto vout = out.vdata();
      vout[it.oofs(0)].Set(tdata[0]);
Martin Reinecke's avatar
Martin Reinecke committed
958
959
960
      size_t i=1, ii=1;
      if (forward)
        for (; i<len-1; i+=2, ++ii)
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
961
          vout[it.oofs(ii)].Set(tdata[i], tdata[i+1]);
Martin Reinecke's avatar
Martin Reinecke committed
962
963
      else
        for (; i<len-1; i+=2, ++ii)
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
964
          vout[it.oofs(ii)].Set(tdata[i], -tdata[i+1]);
Martin Reinecke's avatar
Martin Reinecke committed
965
      if (i<len)
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
966
        vout[it.oofs(ii)].Set(tdata[i]);
Martin Reinecke's avatar
Martin Reinecke committed
967
968
969
      }
    });  // end of parallel region
  }
Martin Reinecke's avatar
Martin Reinecke committed
970
// Complex-to-real transform along a single axis.
//
// Every line of `in` along `axis` is unpacked into a real scratch buffer of
// out.shape(axis) elements, transformed with a pocketfft_r plan (third
// argument of exec() fixed to false) and written to `out`. Unpacking:
//   * buffer element 0 receives the real part of input element 0
//   * input element k (k>=1) fills the buffer pair (buf[2k-1], buf[2k]) with
//     its real and imaginary parts; when `forward` is true the imaginary
//     part is negated
//   * if a single buffer slot is left over, it receives the real part of the
//     next input element
//
//   in, out   complex input and real output arrays
//   axis      axis along which the transform is carried out
//   forward   selects the sign convention described above
//   fct       scaling factor passed to the plan
//   nthreads  thread count request, forwarded to util::thread_count()
template<typename T> MRUTIL_NOINLINE void general_c2r(
  const fmav<Cmplx<T>> &in, fmav<T> &out, size_t axis, bool forward, T fct,
  size_t nthreads)
  {
  const size_t len = out.shape(axis);
  auto plan = get_plan<pocketfft_r<T>>(len);
  execParallel(
    util::thread_count(nthreads, in, axis, native_simd<T>::size()),
    [&](Scheduler &sched)
      {
      constexpr auto vlen = native_simd<T>::size();
      auto storage = alloc_tmp<T,T>(out, len);
      multi_iter<vlen> it(in, out, axis, sched.num_threads(), sched.thread_num());
#ifndef MRUTIL_NO_SIMD
      // Process bundles of vlen lines through the SIMD code path.
      if (vlen>1)
        {
        while (it.remaining()>=vlen)
          {
          it.advance(vlen);
          auto vbuf = reinterpret_cast<native_simd<T> *>(storage.data());
          for (size_t lane=0; lane<vlen; ++lane)
            vbuf[0][lane] = in[it.iofs(lane,0)].r;
          size_t kdst=1, ksrc=1;
          if (forward)
            {
            while (kdst<len-1)
              {
              for (size_t lane=0; lane<vlen; ++lane)
                {
                vbuf[kdst  ][lane] =  in[it.iofs(lane,ksrc)].r;
                vbuf[kdst+1][lane] = -in[it.iofs(lane,ksrc)].i;
                }
              kdst+=2;
              ++ksrc;
              }
            }
          else
            {
            while (kdst<len-1)
              {
              for (size_t lane=0; lane<vlen; ++lane)
                {
                vbuf[kdst  ][lane] = in[it.iofs(lane,ksrc)].r;
                vbuf[kdst+1][lane] = in[it.iofs(lane,ksrc)].i;
                }
              kdst+=2;
              ++ksrc;
              }
            }
          if (kdst<len)
            {
            for (size_t lane=0; lane<vlen; ++lane)
              vbuf[kdst][lane] = in[it.iofs(lane,ksrc)].r;
            }
          plan->exec(vbuf, fct, false);
          copy_output(it, vbuf, out);
          }
        }
#endif
      // Remaining lines are handled one at a time.
      while (it.remaining()>0)
        {
        it.advance(1);
        auto sbuf = reinterpret_cast<T *>(storage.data());
        sbuf[0] = in[it.iofs(0)].r;
        size_t kdst=1, ksrc=1;
        if (forward)
          {
          while (kdst<len-1)
            {
            sbuf[kdst  ] =  in[it.iofs(ksrc)].r;
            sbuf[kdst+1] = -in[it.iofs(ksrc)].i;
            kdst+=2;
            ++ksrc;
            }
          }
        else
          {
          while (kdst<len-1)
            {
            sbuf[kdst  ] = in[it.iofs(ksrc)].r;
            sbuf[kdst+1] = in[it.iofs(ksrc)].i;
            kdst+=2;
            ++ksrc;
            }
          }
        if (kdst<len)
          sbuf[kdst] = in[it.iofs(ksrc)].r;
        plan->exec(sbuf, fct, false);
        copy_output(it, sbuf, out);
        }
      });  // end of parallel region
  }

struct ExecR2R
  {
  bool r2c, forward;

  template <typename T0, typename T, size_t vlen> void operator () (
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1047
    const multi_iter<vlen> &it, const fmav<T0> &in, fmav<T0> &out, T * buf,
Martin Reinecke's avatar
Martin Reinecke committed
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
    const pocketfft_r<T0> &plan, T0 fct) const
    {
    copy_input(it, in, buf);
    if ((!r2c) && forward)
      for (size_t i=2; i<it.length_out(); i+=2)
        buf[i] = -buf[i];
    plan.exec(buf, fct, forward);
    if (r2c && (!forward))
      for (size_t i=2; i<it.length_out(); i+=2)
        buf[i] = -buf[i];
    copy_output(it, buf, out);
    }
  };

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1062
1063
template<typename T> void c2c(const fmav<std::complex<T>> &in,
  fmav<std::complex<T>> &out, const shape_t &axes, bool forward,
Martin Reinecke's avatar
Martin Reinecke committed
1064
  T fct, size_t nthreads=1)
Martin Reinecke's avatar
Martin Reinecke committed
1065
  {
Martin Reinecke's avatar
Martin Reinecke committed
1066
1067
  util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
  if (in.size()==0) return;
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1068
  fmav<Cmplx<T>> in2(reinterpret_cast<const Cmplx<T> *>(in.data()), in);
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
1069
  fmav<Cmplx<T>> out2(reinterpret_cast<Cmplx<T> *>(out.vdata()), out, out.writable());
Martin Reinecke's avatar
Martin Reinecke committed
1070
  general_nd<pocketfft_c<T>>(in2, out2, axes, fct, nthreads, ExecC2C{forward});
Martin Reinecke's avatar
Martin Reinecke committed
1071
1072
  }

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1073
template<typename T> void dct(const fmav<T> &in, fmav<T> &out,
Martin Reinecke's avatar
Martin Reinecke committed
1074
  const shape_t &axes, int type, T fct, bool ortho, size_t nthreads=1)
Martin Reinecke's avatar
Martin Reinecke committed
1075
  {
Martin Reinecke's avatar
Martin Reinecke committed
1076
  if ((type<1) || (type>4)) throw std::invalid_argument("invalid DCT type");
Martin Reinecke's avatar
Martin Reinecke committed
1077
1078
  util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
  if (in.size()==0) return;
Martin Reinecke's avatar
Martin Reinecke committed
1079
1080
  const ExecDcst exec{ortho, type, true};
  if (type==1)
Martin Reinecke's avatar
Martin Reinecke committed
1081
    general_nd<T_dct1<T>>(in, out, axes, fct, nthreads, exec);
Martin Reinecke's avatar
Martin Reinecke committed
1082
  else if (type==4)
Martin Reinecke's avatar
Martin Reinecke committed
1083
    general_nd<T_dcst4<T>>(in, out, axes, fct, nthreads, exec);
Martin Reinecke's avatar
Martin Reinecke committed
1084
  else
Martin Reinecke's avatar
Martin Reinecke committed
1085
    general_nd<T_dcst23<T>>(in, out, axes, fct, nthreads, exec);
Martin Reinecke's avatar
Martin Reinecke committed
1086
1087
  }

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1088
template<typename T> void dst(const fmav<T> &in, fmav<T> &out,
Martin Reinecke's avatar
Martin Reinecke committed
1089
  const shape_t &axes, int type, T fct, bool ortho, size_t nthreads=1)
Martin Reinecke's avatar
Martin Reinecke committed
1090
  {
Martin Reinecke's avatar
Martin Reinecke committed
1091
  if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type");
Martin Reinecke's avatar
Martin Reinecke committed
1092
  util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
Martin Reinecke's avatar
Martin Reinecke committed
1093
1094
  const ExecDcst exec{ortho, type, false};
  if (type==1)
Martin Reinecke's avatar
Martin Reinecke committed
1095
    general_nd<T_dst1<T>>(in, out, axes, fct, nthreads, exec);
Martin Reinecke's avatar
Martin Reinecke committed
1096
  else if (type==4)
Martin Reinecke's avatar
Martin Reinecke committed
1097
    general_nd<T_dcst4<T>>(in, out, axes, fct, nthreads, exec);
Martin Reinecke's avatar
Martin Reinecke committed
1098
  else
Martin Reinecke's avatar
Martin Reinecke committed
1099
    general_nd<T_dcst23<T>>(in, out, axes, fct, nthreads, exec);
Martin Reinecke's avatar
Martin Reinecke committed
1100
1101
  }

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1102
1103
template<typename T> void r2c(const fmav<T> &in,
  fmav<std::complex<T>> &out, size_t axis, bool forward, T fct,
Martin Reinecke's avatar
Martin Reinecke committed
1104
1105
  size_t nthreads=1)
  {
Martin Reinecke's avatar
Martin Reinecke committed
1106
1107
  util::sanity_check_cr(out, in, axis);
  if (in.size()==0) return;
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
1108
  fmav<Cmplx<T>> out2(reinterpret_cast<Cmplx<T> *>(out.vdata()), out, out.writable());
Martin Reinecke's avatar
Martin Reinecke committed
1109
  general_r2c(in, out2, axis, forward, fct, nthreads);
Martin Reinecke's avatar
Martin Reinecke committed
1110
1111
  }

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1112
1113
template<typename T> void r2c(const fmav<T> &in,
  fmav<std::complex<T>> &out, const shape_t &axes,
Martin Reinecke's avatar
Martin Reinecke committed
1114
  bool forward, T fct, size_t nthreads=1)
Martin Reinecke's avatar
Martin Reinecke committed
1115
  {
Martin Reinecke's avatar
Martin Reinecke committed
1116
1117
1118
  util::sanity_check_cr(out, in, axes);
  if (in.size()==0) return;
  r2c(in, out, axes.back(), forward, fct, nthreads);
Martin Reinecke's avatar
Martin Reinecke committed
1119
1120
1121
  if (axes.size()==1) return;

  auto newaxes = shape_t{axes.begin(), --axes.end()};
Martin Reinecke's avatar
Martin Reinecke committed
1122
  c2c(out, out, newaxes, forward, T(1), nthreads);
Martin Reinecke's avatar
Martin Reinecke committed
1123
1124
  }

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1125
1126
template<typename T> void c2r(const fmav<std::complex<T>> &in,
  fmav<T> &out,  size_t axis, bool forward, T fct, size_t nthreads=1)
Martin Reinecke's avatar
Martin Reinecke committed
1127
  {
Martin Reinecke's avatar
Martin Reinecke committed
1128
1129
  util::sanity_check_cr(in, out, axis);
  if (in.size()==0) return;
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1130
  fmav<Cmplx<T>> in2(reinterpret_cast<const Cmplx<T> *>(in.data()), in);
Martin Reinecke's avatar
Martin Reinecke committed
1131
  general_c2r(in2, out, axis, forward, fct, nthreads);
Martin Reinecke's avatar
Martin Reinecke committed
1132
1133
  }

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1134
1135
template<typename T> void c2r(const fmav<std::complex<T>> &in,
  fmav<T> &out, const shape_t &axes, bool forward, T fct,
Martin Reinecke's avatar
Martin Reinecke committed
1136
1137
1138
  size_t nthreads=1)
  {
  if (axes.size()==1)
Martin Reinecke's avatar
Martin Reinecke committed
1139
1140
1141
    return c2r(in, out, axes[0], forward, fct, nthreads);
  util::sanity_check_cr(in, out, axes);
  if (in.size()==0) return;
1142
  fmav<std::complex<T>> atmp(in.shape());
1143
  auto newaxes = shape_t{axes.begin(), --axes.end()};
Martin Reinecke's avatar
Martin Reinecke committed
1144
1145
  c2c(in, atmp, newaxes, forward, T(1), nthreads);
  c2r(atmp, out, axes.back(), forward, fct, nthreads);
Martin Reinecke's avatar
Martin Reinecke committed
1146
1147
  }

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1148
1149
template<typename T> void r2r_fftpack(const fmav<T> &in,
  fmav<T> &out, const shape_t &axes, bool real2hermitian, bool forward,
Martin Reinecke's avatar
Martin Reinecke committed
1150
  T fct, size_t nthreads=1)
Martin Reinecke's avatar
Martin Reinecke committed
1151
  {
Martin Reinecke's avatar
Martin Reinecke committed
1152
1153
1154
  util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
  if (in.size()==0) return;
  general_nd<pocketfft_r<T>>(in, out, axes, fct, nthreads,
Martin Reinecke's avatar
Martin Reinecke committed
1155
1156
1157
    ExecR2R{real2hermitian, forward});
  }

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1158
1159
template<typename T> void r2r_separable_hartley(const fmav<T> &in,
  fmav<T> &out, const shape_t &axes, T fct, size_t nthreads=1)
Martin Reinecke's avatar
Martin Reinecke committed
1160
  {
Martin Reinecke's avatar
Martin Reinecke committed
1161
1162
1163
  util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
  if (in.size()==0) return;
  general_nd<pocketfft_r<T>>(in, out, axes, fct, nthreads, ExecHartley{},
Martin Reinecke's avatar
Martin Reinecke committed
1164
1165
1166
    false);
  }

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1167
1168
template<typename T> void r2r_genuine_hartley(const fmav<T> &in,
  fmav<T> &out, const shape_t &axes, T fct, size_t nthreads=1)
Martin Reinecke's avatar
Martin Reinecke committed
1169
1170
  {
  if (axes.size()==1)
Martin Reinecke's avatar
Martin Reinecke committed
1171
1172
1173
1174
    return r2r_separable_hartley(in, out, axes, fct, nthreads);
  util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
  if (in.size()==0) return;
  shape_t tshp(in.shape());
Martin Reinecke's avatar
Martin Reinecke committed
1175
  tshp[axes.back()] = tshp[axes.back()]/2+1;
1176
  fmav<std::complex<T>> atmp(tshp);
Martin Reinecke's avatar
Martin Reinecke committed
1177
  r2c(in, atmp, axes, true, fct, nthreads);
1178
  FmavIter iin(atmp);
Martin Reinecke's avatar
Martin Reinecke committed
1179
  rev_iter iout(out, axes);
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
1180
  auto vout = out.vdata();
Martin Reinecke's avatar
Martin Reinecke committed
1181
1182
1183
  while(iin.remaining()>0)
    {
    auto v = atmp[iin.ofs()];
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
1184
1185
    vout[iout.ofs()] = v.real()+v.imag();
    vout[iout.rev_ofs()] = v.real()-v.imag();
Martin Reinecke's avatar
Martin Reinecke committed
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
    iin.advance(); iout.advance();
    }
  }

} // namespace detail_fft

using detail_fft::FORWARD;
using detail_fft::BACKWARD;
using detail_fft::c2c;
using detail_fft::c2r;
using detail_fft::r2c;
using detail_fft::r2r_fftpack;
using detail_fft::r2r_sep