fft.h 38.6 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
/*
This file is part of pocketfft.

Copyright (C) 2010-2020 Max-Planck-Society
Copyright (C) 2019 Peter Bell

For the odd-sized DCT-IV transforms:
  Copyright (C) 2003, 2007-14 Matteo Frigo
  Copyright (C) 2003, 2007-14 Massachusetts Institute of Technology

Authors: Martin Reinecke, Peter Bell

All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this
  list of conditions and the following disclaimer in the documentation and/or
  other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its contributors may
  be used to endorse or promote products derived from this software without
  specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

#ifndef MRUTIL_FFT_H
#define MRUTIL_FFT_H
41

42
#include "mr_util/math/fft1d.h"
Martin Reinecke's avatar
Martin Reinecke committed
43

Martin Reinecke's avatar
Martin Reinecke committed
44
45
46
47
48
49
50
51
52
53
#ifndef POCKETFFT_CACHE_SIZE
#define POCKETFFT_CACHE_SIZE 16
#endif

#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdlib>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>
Martin Reinecke's avatar
Martin Reinecke committed
55
56
57
#if POCKETFFT_CACHE_SIZE!=0
#include <array>
#endif
58
59
60
#include "mr_util/infra/threading.h"
#include "mr_util/infra/simd.h"
#include "mr_util/infra/mav.h"
Martin Reinecke's avatar
Martin Reinecke committed
61
62
63
#ifndef MRUTIL_NO_THREADING
#include <mutex>
#endif
Martin Reinecke's avatar
Martin Reinecke committed
64
65
66
67
68

namespace mr {

namespace detail_fft {

Martin Reinecke's avatar
Martin Reinecke committed
69
// Extent and stride vector types, taken over from fmav_info so that the
// transform code uses the same representation as the array descriptors.
using stride_t = fmav_info::stride_t;
using shape_t = fmav_info::shape_t;
Martin Reinecke's avatar
Martin Reinecke committed
71

Martin Reinecke's avatar
Martin Reinecke committed
72
73
74
75
76
// Named values for the boolean transform-direction flag (true == forward).
constexpr bool FORWARD = true;
constexpr bool BACKWARD = !FORWARD;

struct util // hack to avoid duplicate symbols
  {
Martin Reinecke's avatar
Martin Reinecke committed
77
  static void sanity_check_axes(size_t ndim, const shape_t &axes)
Martin Reinecke's avatar
Martin Reinecke committed
78
79
    {
    shape_t tmp(ndim,0);
Martin Reinecke's avatar
Martin Reinecke committed
80
    if (axes.empty()) throw std::invalid_argument("no axes specified");
Martin Reinecke's avatar
Martin Reinecke committed
81
82
    for (auto ax : axes)
      {
Martin Reinecke's avatar
Martin Reinecke committed
83
84
      if (ax>=ndim) throw std::invalid_argument("bad axis number");
      if (++tmp[ax]>1) throw std::invalid_argument("axis specified repeatedly");
Martin Reinecke's avatar
Martin Reinecke committed
85
86
87
      }
    }

Martin Reinecke's avatar
Martin Reinecke committed
88
89
  static MRUTIL_NOINLINE void sanity_check_onetype(const fmav_info &a1,
    const fmav_info &a2, bool inplace, const shape_t &axes)
Martin Reinecke's avatar
Martin Reinecke committed
90
    {
Martin Reinecke's avatar
Martin Reinecke committed
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
    sanity_check_axes(a1.ndim(), axes);
    MR_assert(a1.conformable(a2), "array sizes are not conformable");
    if (inplace) MR_assert(a1.stride()==a2.stride(), "stride mismatch");
    }
  static MRUTIL_NOINLINE void sanity_check_cr(const fmav_info &ac,
    const fmav_info &ar, const shape_t &axes)
    {
    sanity_check_axes(ac.ndim(), axes);
    MR_assert(ac.ndim()==ar.ndim(), "dimension mismatch");
    for (size_t i=0; i<ac.ndim(); ++i)
      MR_assert(ac.shape(i)== (i==axes.back()) ? (ar.shape(i)/2+1) : ar.shape(i),
        "axis length mismatch");
    }
  static MRUTIL_NOINLINE void sanity_check_cr(const fmav_info &ac,
    const fmav_info &ar, const size_t axis)
    {
    if (axis>=ac.ndim()) throw std::invalid_argument("bad axis number");
    MR_assert(ac.ndim()==ar.ndim(), "dimension mismatch");
    for (size_t i=0; i<ac.ndim(); ++i)
      MR_assert(ac.shape(i)== (i==axis) ? (ar.shape(i)/2+1) : ar.shape(i),
        "axis length mismatch");
Martin Reinecke's avatar
Martin Reinecke committed
112
113
114
    }

#ifdef MRUTIL_NO_THREADING
Martin Reinecke's avatar
Martin Reinecke committed
115
  static size_t thread_count (size_t /*nthreads*/, const fmav_info &/*info*/,
Martin Reinecke's avatar
Martin Reinecke committed
116
117
118
    size_t /*axis*/, size_t /*vlen*/)
    { return 1; }
#else
Martin Reinecke's avatar
Martin Reinecke committed
119
  static size_t thread_count (size_t nthreads, const fmav_info &info,
Martin Reinecke's avatar
Martin Reinecke committed
120
121
122
    size_t axis, size_t vlen)
    {
    if (nthreads==1) return 1;
Martin Reinecke's avatar
Martin Reinecke committed
123
124
125
    size_t size = info.size();
    size_t parallel = size / (info.shape(axis) * vlen);
    if (info.shape(axis) < 1000)
Martin Reinecke's avatar
Martin Reinecke committed
126
      parallel /= 4;
Martin Reinecke's avatar
Martin Reinecke committed
127
    size_t max_threads = (nthreads==0) ? mr::max_threads() : nthreads;
128
    return std::max(size_t(1), std::min(parallel, max_threads));
Martin Reinecke's avatar
Martin Reinecke committed
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
    }
#endif
  };


//
// sine/cosine transforms
//

template<typename T0> class T_dct1
  {
  private:
    pocketfft_r<T0> fftplan;

  public:
Martin Reinecke's avatar
Martin Reinecke committed
144
    MRUTIL_NOINLINE T_dct1(size_t length)
Martin Reinecke's avatar
Martin Reinecke committed
145
146
      : fftplan(2*(length-1)) {}

Martin Reinecke's avatar
Martin Reinecke committed
147
    template<typename T> MRUTIL_NOINLINE void exec(T c[], T0 fct, bool ortho,
Martin Reinecke's avatar
Martin Reinecke committed
148
149
150
151
152
153
      int /*type*/, bool /*cosine*/) const
      {
      constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
      size_t N=fftplan.length(), n=N/2+1;
      if (ortho)
        { c[0]*=sqrt2; c[n-1]*=sqrt2; }
154
      aligned_array<T> tmp(N);
Martin Reinecke's avatar
Martin Reinecke committed
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
      tmp[0] = c[0];
      for (size_t i=1; i<n; ++i)
        tmp[i] = tmp[N-i] = c[i];
      fftplan.exec(tmp.data(), fct, true);
      c[0] = tmp[0];
      for (size_t i=1; i<n; ++i)
        c[i] = tmp[2*i-1];
      if (ortho)
        { c[0]*=sqrt2*T0(0.5); c[n-1]*=sqrt2*T0(0.5); }
      }

    size_t length() const { return fftplan.length()/2+1; }
  };

template<typename T0> class T_dst1
  {
  private:
    pocketfft_r<T0> fftplan;

  public:
Martin Reinecke's avatar
Martin Reinecke committed
175
    MRUTIL_NOINLINE T_dst1(size_t length)
Martin Reinecke's avatar
Martin Reinecke committed
176
177
      : fftplan(2*(length+1)) {}

Martin Reinecke's avatar
Martin Reinecke committed
178
    template<typename T> MRUTIL_NOINLINE void exec(T c[], T0 fct,
Martin Reinecke's avatar
Martin Reinecke committed
179
180
181
      bool /*ortho*/, int /*type*/, bool /*cosine*/) const
      {
      size_t N=fftplan.length(), n=N/2-1;
182
      aligned_array<T> tmp(N);
Martin Reinecke's avatar
Martin Reinecke committed
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
      tmp[0] = tmp[n+1] = c[0]*0;
      for (size_t i=0; i<n; ++i)
        { tmp[i+1]=c[i]; tmp[N-1-i]=-c[i]; }
      fftplan.exec(tmp.data(), fct, true);
      for (size_t i=0; i<n; ++i)
        c[i] = -tmp[2*i+2];
      }

    size_t length() const { return fftplan.length()/2-1; }
  };

template<typename T0> class T_dcst23
  {
  private:
    pocketfft_r<T0> fftplan;
Martin Reinecke's avatar
Martin Reinecke committed
198
    std::vector<T0> twiddle;
Martin Reinecke's avatar
Martin Reinecke committed
199
200

  public:
Martin Reinecke's avatar
Martin Reinecke committed
201
    MRUTIL_NOINLINE T_dcst23(size_t length)
Martin Reinecke's avatar
Martin Reinecke committed
202
203
204
205
206
207
208
      : fftplan(length), twiddle(length)
      {
      UnityRoots<T0,Cmplx<T0>> tw(4*length);
      for (size_t i=0; i<length; ++i)
        twiddle[i] = tw[i+1].r;
      }

Martin Reinecke's avatar
Martin Reinecke committed
209
    template<typename T> MRUTIL_NOINLINE void exec(T c[], T0 fct, bool ortho,
Martin Reinecke's avatar
Martin Reinecke committed
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
      int type, bool cosine) const
      {
      constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
      size_t N=length();
      size_t NS2 = (N+1)/2;
      if (type==2)
        {
        if (!cosine)
          for (size_t k=1; k<N; k+=2)
            c[k] = -c[k];
        c[0] *= 2;
        if ((N&1)==0) c[N-1]*=2;
        for (size_t k=1; k<N-1; k+=2)
          MPINPLACE(c[k+1], c[k]);
        fftplan.exec(c, fct, false);
        for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
          {
          T t1 = twiddle[k-1]*c[kc]+twiddle[kc-1]*c[k];
          T t2 = twiddle[k-1]*c[k]-twiddle[kc-1]*c[kc];
          c[k] = T0(0.5)*(t1+t2); c[kc]=T0(0.5)*(t1-t2);
          }
        if ((N&1)==0)
          c[NS2] *= twiddle[NS2-1];
        if (!cosine)
          for (size_t k=0, kc=N-1; k<kc; ++k, --kc)
Martin Reinecke's avatar
Martin Reinecke committed
235
            std::swap(c[k], c[kc]);
Martin Reinecke's avatar
Martin Reinecke committed
236
237
238
239
240
241
242
        if (ortho) c[0]*=sqrt2*T0(0.5);
        }
      else
        {
        if (ortho) c[0]*=sqrt2;
        if (!cosine)
          for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
Martin Reinecke's avatar
Martin Reinecke committed
243
            std::swap(c[k], c[kc]);
Martin Reinecke's avatar
Martin Reinecke committed
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
        for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
          {
          T t1=c[k]+c[kc], t2=c[k]-c[kc];
          c[k] = twiddle[k-1]*t2+twiddle[kc-1]*t1;
          c[kc]= twiddle[k-1]*t1-twiddle[kc-1]*t2;
          }
        if ((N&1)==0)
          c[NS2] *= 2*twiddle[NS2-1];
        fftplan.exec(c, fct, true);
        for (size_t k=1; k<N-1; k+=2)
          MPINPLACE(c[k], c[k+1]);
        if (!cosine)
          for (size_t k=1; k<N; k+=2)
            c[k] = -c[k];
        }
      }

    size_t length() const { return fftplan.length(); }
  };

template<typename T0> class T_dcst4
  {
  private:
    size_t N;
Martin Reinecke's avatar
Martin Reinecke committed
268
269
    std::unique_ptr<pocketfft_c<T0>> fft;
    std::unique_ptr<pocketfft_r<T0>> rfft;
270
    aligned_array<Cmplx<T0>> C2;
Martin Reinecke's avatar
Martin Reinecke committed
271
272

  public:
Martin Reinecke's avatar
Martin Reinecke committed
273
    MRUTIL_NOINLINE T_dcst4(size_t length)
Martin Reinecke's avatar
Martin Reinecke committed
274
275
276
277
278
279
280
281
282
283
284
285
286
      : N(length),
        fft((N&1) ? nullptr : new pocketfft_c<T0>(N/2)),
        rfft((N&1)? new pocketfft_r<T0>(N) : nullptr),
        C2((N&1) ? 0 : N/2)
      {
      if ((N&1)==0)
        {
        UnityRoots<T0,Cmplx<T0>> tw(16*N);
        for (size_t i=0; i<N/2; ++i)
          C2[i] = conj(tw[8*i+1]);
        }
      }

Martin Reinecke's avatar
Martin Reinecke committed
287
    template<typename T> MRUTIL_NOINLINE void exec(T c[], T0 fct,
Martin Reinecke's avatar
Martin Reinecke committed
288
289
290
291
292
      bool /*ortho*/, int /*type*/, bool cosine) const
      {
      size_t n2 = N/2;
      if (!cosine)
        for (size_t k=0, kc=N-1; k<n2; ++k, --kc)
Martin Reinecke's avatar
Martin Reinecke committed
293
          std::swap(c[k], c[kc]);
Martin Reinecke's avatar
Martin Reinecke committed
294
295
296
297
298
299
      if (N&1)
        {
        // The following code is derived from the FFTW3 function apply_re11()
        // and is released under the 3-clause BSD license with friendly
        // permission of Matteo Frigo and Steven G. Johnson.

300
        aligned_array<T> y(N);
Martin Reinecke's avatar
Martin Reinecke committed
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
        {
        size_t i=0, m=n2;
        for (; m<N; ++i, m+=4)
          y[i] = c[m];
        for (; m<2*N; ++i, m+=4)
          y[i] = -c[2*N-m-1];
        for (; m<3*N; ++i, m+=4)
          y[i] = -c[m-2*N];
        for (; m<4*N; ++i, m+=4)
          y[i] = c[4*N-m-1];
        for (; i<N; ++i, m+=4)
          y[i] = c[m-4*N];
        }
        rfft->exec(y.data(), fct, true);
        {
        auto SGN = [](size_t i)
           {
           constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
           return (i&2) ? -sqrt2 : sqrt2;
           };
        c[n2] = y[0]*SGN(n2+1);
        size_t i=0, i1=1, k=1;
        for (; k<n2; ++i, ++i1, k+=2)
          {
          c[i    ] = y[2*k-1]*SGN(i1)     + y[2*k  ]*SGN(i);
          c[N -i1] = y[2*k-1]*SGN(N -i)   - y[2*k  ]*SGN(N -i1);
          c[n2-i1] = y[2*k+1]*SGN(n2-i)   - y[2*k+2]*SGN(n2-i1);
          c[n2+i1] = y[2*k+1]*SGN(n2+i+2) + y[2*k+2]*SGN(n2+i1);
          }
        if (k == n2)
          {
          c[i   ] = y[2*k-1]*SGN(i+1) + y[2*k]*SGN(i);
          c[N-i1] = y[2*k-1]*SGN(i+2) + y[2*k]*SGN(i1);
          }
        }

        // FFTW-derived code ends here
        }
      else
        {
        // even length algorithm from
        // https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/
343
        aligned_array<Cmplx<T>> y(n2);
Martin Reinecke's avatar
Martin Reinecke committed
344
345
346
347
348
349
350
351
        for(size_t i=0; i<n2; ++i)
          {
          y[i].Set(c[2*i],c[N-1-2*i]);
          y[i] *= C2[i];
          }
        fft->exec(y.data(), fct, true);
        for(size_t i=0, ic=n2-1; i<n2; ++i, --ic)
          {
Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
352
353
          c[2*i  ] = T0( 2)*(y[i ].r*C2[i ].r-y[i ].i*C2[i ].i);
          c[2*i+1] = T0(-2)*(y[ic].i*C2[ic].r+y[ic].r*C2[ic].i);
Martin Reinecke's avatar
Martin Reinecke committed
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
          }
        }
      if (!cosine)
        for (size_t k=1; k<N; k+=2)
          c[k] = -c[k];
      }

    size_t length() const { return N; }
  };


//
// multi-D infrastructure
//

Martin Reinecke's avatar
Martin Reinecke committed
369
/// Returns a shared plan object of type T for \a length.
/// T must be constructible from a size_t and provide length().
/// Unless POCKETFFT_CACHE_SIZE is 0, plans are kept in a function-local
/// least-recently-used cache with POCKETFFT_CACHE_SIZE slots (one cache per
/// T); without MRUTIL_NO_THREADING the cache is guarded by a mutex.
template<typename T> std::shared_ptr<T> get_plan(size_t length)
  {
#if POCKETFFT_CACHE_SIZE==0
  return std::make_shared<T>(length);
#else
  constexpr size_t nmax=POCKETFFT_CACHE_SIZE;
  static std::array<std::shared_ptr<T>, nmax> cache;
  static std::array<size_t, nmax> last_access{{0}};
  static size_t access_counter = 0;

  // Returns the cached plan with the requested length (refreshing its
  // access stamp), or nullptr if there is none.
  auto find_in_cache = [&]() -> std::shared_ptr<T>
    {
    for (size_t slot=0; slot<nmax; ++slot)
      {
      if (!cache[slot] || (cache[slot]->length()!=length))
        continue;
      // already the most recent entry -> no stamp update needed
      if (last_access[slot]!=access_counter)
        {
        last_access[slot] = ++access_counter;
        if (access_counter == 0)  // counter wrapped around: reset all stamps
          last_access.fill(0);
        }
      return cache[slot];
      }
    return nullptr;
    };

  // Stores newplan in the slot with the oldest access stamp (first such
  // slot on ties).
  auto insert_into_cache = [&](const std::shared_ptr<T> &newplan)
    {
    const size_t lru = size_t(std::min_element(last_access.begin(),
      last_access.end()) - last_access.begin());
    cache[lru] = newplan;
    last_access[lru] = ++access_counter;
    };

#ifdef MRUTIL_NO_THREADING
  if (auto hit = find_in_cache()) return hit;
  auto plan = std::make_shared<T>(length);
  insert_into_cache(plan);
  return plan;
#else
  static std::mutex mut;

  {
  std::lock_guard<std::mutex> lock(mut);
  if (auto hit = find_in_cache()) return hit;
  }
  // Construct outside the lock so concurrent callers are not serialized on
  // plan creation.
  auto plan = std::make_shared<T>(length);
  {
  std::lock_guard<std::mutex> lock(mut);
  // Another thread may have cached an equivalent plan meanwhile; prefer it.
  if (auto hit = find_in_cache()) return hit;
  insert_into_cache(plan);
  }
  return plan;
#endif
#endif
  }

template<size_t N> class multi_iter
  {
  private:
Martin Reinecke's avatar
Martin Reinecke committed
439
440
441
442
443
    shape_t shp, pos;
    stride_t str_i, str_o;
    size_t cshp_i, cshp_o, rem;
    ptrdiff_t cstr_i, cstr_o, sstr_i, sstr_o, p_ii, p_i[N], p_oi, p_o[N];
    bool uni_i, uni_o;
Martin Reinecke's avatar
Martin Reinecke committed
444
445
446

    void advance_i()
      {
Martin Reinecke's avatar
Martin Reinecke committed
447
      for (size_t i=0; i<pos.size(); ++i)
Martin Reinecke's avatar
Martin Reinecke committed
448
        {
Martin Reinecke's avatar
Martin Reinecke committed
449
450
451
        p_ii += str_i[i];
        p_oi += str_o[i];
        if (++pos[i] < shp[i])
Martin Reinecke's avatar
Martin Reinecke committed
452
453
          return;
        pos[i] = 0;
Martin Reinecke's avatar
Martin Reinecke committed
454
455
        p_ii -= ptrdiff_t(shp[i])*str_i[i];
        p_oi -= ptrdiff_t(shp[i])*str_o[i];
Martin Reinecke's avatar
Martin Reinecke committed
456
457
458
459
        }
      }

  public:
Martin Reinecke's avatar
Martin Reinecke committed
460
461
462
    multi_iter(const fmav_info &iarr, const fmav_info &oarr, size_t idim,
      size_t nshares, size_t myshare)
      : rem(iarr.size()/iarr.shape(idim)), sstr_i(0), sstr_o(0), p_ii(0), p_oi(0)
Martin Reinecke's avatar
Martin Reinecke committed
463
      {
Martin Reinecke's avatar
Martin Reinecke committed
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
      MR_assert(oarr.ndim()==iarr.ndim(), "dimension mismatch");
      MR_assert(iarr.ndim()>=1, "not enough dimensions");
      // Sort the extraneous dimensions in order of ascending output stride;
      // this should improve overall cache re-use and avoid clashes between
      // threads as much as possible.
      shape_t idx(iarr.ndim());
      std::iota(idx.begin(), idx.end(), 0);
      sort(idx.begin(), idx.end(),
        [&oarr](size_t i1, size_t i2) {return oarr.stride(i1) < oarr.stride(i2);});
      for (auto i: idx)
        if (i!=idim)
          {
          pos.push_back(0);
          MR_assert(iarr.shape(i)==oarr.shape(i), "shape mismatch");
          shp.push_back(iarr.shape(i));
          str_i.push_back(iarr.stride(i));
          str_o.push_back(oarr.stride(i));
          }
      MR_assert(idim<iarr.ndim(), "bad active dimension");
      cstr_i = iarr.stride(idim);
      cstr_o = oarr.stride(idim);
      cshp_i = iarr.shape(idim);
      cshp_o = oarr.shape(idim);

// collapse unneeded dimensions
      bool done = false;
      while(!done)
        {
        done=true;
        for (size_t i=1; i<shp.size(); ++i)
          if ((str_i[i] == str_i[i-1]*ptrdiff_t(shp[i-1]))
           && (str_o[i] == str_o[i-1]*ptrdiff_t(shp[i-1])))
            {
            shp[i-1] *= shp[i];
            str_i.erase(str_i.begin()+ptrdiff_t(i));
            str_o.erase(str_o.begin()+ptrdiff_t(i));
            shp.erase(shp.begin()+ptrdiff_t(i));
            pos.pop_back();
            done=false;
            }
        }
      if (pos.size()>0)
        {
        sstr_i = str_i[0];
        sstr_o = str_o[0];
        }

Martin Reinecke's avatar
Martin Reinecke committed
511
      if (nshares==1) return;
Martin Reinecke's avatar
Martin Reinecke committed
512
513
      if (nshares==0) throw std::runtime_error("can't run with zero threads");
      if (myshare>=nshares) throw std::runtime_error("impossible share requested");
Martin Reinecke's avatar
Martin Reinecke committed
514
515
516
517
518
519
520
      size_t nbase = rem/nshares;
      size_t additional = rem%nshares;
      size_t lo = myshare*nbase + ((myshare<additional) ? myshare : additional);
      size_t hi = lo+nbase+(myshare<additional);
      size_t todo = hi-lo;

      size_t chunk = rem;
Martin Reinecke's avatar
Martin Reinecke committed
521
      for (size_t i2=0, i=pos.size()-1; i2<pos.size(); ++i2,--i)
Martin Reinecke's avatar
Martin Reinecke committed
522
        {
Martin Reinecke's avatar
Martin Reinecke committed
523
        chunk /= shp[i];
Martin Reinecke's avatar
Martin Reinecke committed
524
525
        size_t n_advance = lo/chunk;
        pos[i] += n_advance;
Martin Reinecke's avatar
Martin Reinecke committed
526
527
        p_ii += ptrdiff_t(n_advance)*str_i[i];
        p_oi += ptrdiff_t(n_advance)*str_o[i];
Martin Reinecke's avatar
Martin Reinecke committed
528
529
        lo -= n_advance*chunk;
        }
Martin Reinecke's avatar
Martin Reinecke committed
530
      MR_assert(lo==0, "must not happen");
Martin Reinecke's avatar
Martin Reinecke committed
531
532
533
534
      rem = todo;
      }
    void advance(size_t n)
      {
Martin Reinecke's avatar
Martin Reinecke committed
535
      if (rem<n) throw std::runtime_error("underrun");
Martin Reinecke's avatar
Martin Reinecke committed
536
537
538
539
540
541
      for (size_t i=0; i<n; ++i)
        {
        p_i[i] = p_ii;
        p_o[i] = p_oi;
        advance_i();
        }
Martin Reinecke's avatar
Martin Reinecke committed
542
543
544
545
546
547
      uni_i = uni_o = true;
      for (size_t i=1; i<n; ++i)
        {
        uni_i = uni_i && (p_i[i]-p_i[i-1] == sstr_i);
        uni_o = uni_o && (p_o[i]-p_o[i-1] == sstr_o);
        }
Martin Reinecke's avatar
Martin Reinecke committed
548
549
      rem -= n;
      }
Martin Reinecke's avatar
Martin Reinecke committed
550
551
552
553
554
555
556
557
558
559
560
561
562
563
    ptrdiff_t iofs(size_t i) const { return p_i[0] + ptrdiff_t(i)*cstr_i; }
    ptrdiff_t iofs(size_t j, size_t i) const { return p_i[j] + ptrdiff_t(i)*cstr_i; }
    ptrdiff_t iofs_uni(size_t j, size_t i) const { return p_i[0] + ptrdiff_t(j)*sstr_i + ptrdiff_t(i)*cstr_i; }
    ptrdiff_t oofs(size_t i) const { return p_o[0] + ptrdiff_t(i)*cstr_o; }
    ptrdiff_t oofs(size_t j, size_t i) const { return p_o[j] + ptrdiff_t(i)*cstr_o; }
    ptrdiff_t oofs_uni(size_t j, size_t i) const { return p_o[0] + ptrdiff_t(j)*sstr_o + ptrdiff_t(i)*cstr_o; }
    bool uniform_i() const { return uni_i; } 
    ptrdiff_t unistride_i() const { return sstr_i; } 
    bool uniform_o() const { return uni_o; } 
    ptrdiff_t unistride_o() const { return sstr_o; } 
    size_t length_in() const { return cshp_i; }
    size_t length_out() const { return cshp_o; }
    ptrdiff_t stride_in() const { return cstr_i; }
    ptrdiff_t stride_out() const { return cstr_o; }
Martin Reinecke's avatar
Martin Reinecke committed
564
565
566
567
568
569
570
    size_t remaining() const { return rem; }
  };

class rev_iter
  {
  private:
    shape_t pos;
Martin Reinecke's avatar
Martin Reinecke committed
571
    fmav_info arr;
Martin Reinecke's avatar
Martin Reinecke committed
572
573
    std::vector<char> rev_axis;
    std::vector<char> rev_jump;
Martin Reinecke's avatar
Martin Reinecke committed
574
575
576
577
578
579
    size_t last_axis, last_size;
    shape_t shp;
    ptrdiff_t p, rp;
    size_t rem;

  public:
Martin Reinecke's avatar
Martin Reinecke committed
580
    rev_iter(const fmav_info &arr_, const shape_t &axes)
Martin Reinecke's avatar
Martin Reinecke committed
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
      : pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0),
        rev_jump(arr_.ndim(), 1), p(0), rp(0)
      {
      for (auto ax: axes)
        rev_axis[ax]=1;
      last_axis = axes.back();
      last_size = arr.shape(last_axis)/2 + 1;
      shp = arr.shape();
      shp[last_axis] = last_size;
      rem=1;
      for (auto i: shp)
        rem *= i;
      }
    void advance()
      {
      --rem;
      for (int i_=int(pos.size())-1; i_>=0; --i_)
        {
        auto i = size_t(i_);
        p += arr.stride(i);
        if (!rev_axis[i])
          rp += arr.stride(i);
        else
          {
          rp -= arr.stride(i);
          if (rev_jump[i])
            {
            rp += ptrdiff_t(arr.shape(i))*arr.stride(i);
            rev_jump[i] = 0;
            }
          }
        if (++pos[i] < shp[i])
          return;
        pos[i] = 0;
        p -= ptrdiff_t(shp[i])*arr.stride(i);
        if (rev_axis[i])
          {
          rp -= ptrdiff_t(arr.shape(i)-shp[i])*arr.stride(i);
          rev_jump[i] = 1;
          }
        else
          rp -= ptrdiff_t(shp[i])*arr.stride(i);
        }
      }
    ptrdiff_t ofs() const { return p; }
    ptrdiff_t rev_ofs() const { return rp; }
    size_t remaining() const { return rem; }
  };

Martin Reinecke's avatar
Martin Reinecke committed
630
631
/// Allocates scratch space for transforming \a info along an axis of length
/// \a axsize: room for native_simd<T0>::size() lines when at least that many
/// lines exist, otherwise for a single line.
template<typename T, typename T0> aligned_array<T> alloc_tmp
  (const fmav_info &info, size_t axsize)
  {
  // A zero-length axis needs no scratch space; it would also make the
  // division below a division by zero.
  if (axsize==0) return aligned_array<T>(0);
  auto othersize = info.size()/axsize;
  constexpr auto vlen = native_simd<T0>::size();
  auto tmpsize = axsize*((othersize>=vlen) ? vlen : 1);
  return aligned_array<T>(tmpsize);
  }

639
template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
Martin Reinecke's avatar
Martin Reinecke committed
640
641
642
643
644
645
646
647
  const fmav<Cmplx<T>> &src, Cmplx<native_simd<T>> *MRUTIL_RESTRICT dst)
  {
  if (it.uniform_i())
    {
    auto ptr = &src[it.iofs_uni(0,0)];
    auto jstr = it.unistride_i();
    auto istr = it.stride_in();
    if (istr==1)
648
      for (size_t i=0; i<it.length_in(); ++i)
649
650
        {
        Cmplx<native_simd<T>> stmp;
651
        for (size_t j=0; j<vlen; ++j)
652
          {
653
          auto tmp = ptr[ptrdiff_t(j)*jstr+ptrdiff_t(i)];
654
655
656
657
658
          stmp.r[j] = tmp.r;
          stmp.i[j] = tmp.i;
          }
        dst[i] = stmp;
        }
Martin Reinecke's avatar
Martin Reinecke committed
659
    else if (jstr==1)
660
      for (size_t i=0; i<it.length_in(); ++i)
661
662
        {
        Cmplx<native_simd<T>> stmp;
663
        for (size_t j=0; j<vlen; ++j)
664
          {
665
          auto tmp = ptr[ptrdiff_t(j)+ptrdiff_t(i)*istr];
666
667
668
669
670
          stmp.r[j] = tmp.r;
          stmp.i[j] = tmp.i;
          }
        dst[i] = stmp;
        }
Martin Reinecke's avatar
Martin Reinecke committed
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
    else
      for (size_t i=0; i<it.length_in(); ++i)
        {
        Cmplx<native_simd<T>> stmp;
        for (size_t j=0; j<vlen; ++j)
          {
          auto tmp = src[it.iofs_uni(j,i)];
          stmp.r[j] = tmp.r;
          stmp.i[j] = tmp.i;
          }
        dst[i] = stmp;
        }
    }
  else
    for (size_t i=0; i<it.length_in(); ++i)
686
      {
Martin Reinecke's avatar
Martin Reinecke committed
687
688
689
690
691
692
693
694
      Cmplx<native_simd<T>> stmp;
      for (size_t j=0; j<vlen; ++j)
        {
        auto tmp = src[it.iofs(j,i)];
        stmp.r[j] = tmp.r;
        stmp.i[j] = tmp.i;
        }
      dst[i] = stmp;
695
      }
Martin Reinecke's avatar
Martin Reinecke committed
696
697
  }

698
template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
699
  const fmav<T> &src, native_simd<T> *MRUTIL_RESTRICT dst)
Martin Reinecke's avatar
Martin Reinecke committed
700
  {
Martin Reinecke's avatar
Martin Reinecke committed
701
  if (it.uniform_i())
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
    {
    auto ptr = &src[it.iofs_uni(0,0)];
    auto jstr = it.unistride_i();
    auto istr = it.stride_in();
    if (istr==1)
      for (size_t i=0; i<it.length_in(); ++i)
        for (size_t j=0; j<vlen; ++j)
          dst[i][j] = ptr[ptrdiff_t(j)*jstr + ptrdiff_t(i)];
    else if (jstr==1)
      for (size_t i=0; i<it.length_in(); ++i)
        for (size_t j=0; j<vlen; ++j)
          dst[i][j] = ptr[ptrdiff_t(j) + ptrdiff_t(i)*istr];
    else
      for (size_t i=0; i<it.length_in(); ++i)
        for (size_t j=0; j<vlen; ++j)
          dst[i][j] = src[it.iofs_uni(j,i)];
    }
Martin Reinecke's avatar
Martin Reinecke committed
719
  else
720
    for (size_t i=0; i<it.length_in(); ++i)
Martin Reinecke's avatar
Martin Reinecke committed
721
722
      for (size_t j=0; j<vlen; ++j)
        dst[i][j] = src[it.iofs(j,i)];
Martin Reinecke's avatar
Martin Reinecke committed
723
724
  }

725
template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
726
  const fmav<T> &src, T *MRUTIL_RESTRICT dst)
Martin Reinecke's avatar
Martin Reinecke committed
727
728
729
730
731
732
  {
  if (dst == &src[it.iofs(0)]) return;  // in-place
  for (size_t i=0; i<it.length_in(); ++i)
    dst[i] = src[it.iofs(i)];
  }

733
template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
734
  const Cmplx<native_simd<T>> *MRUTIL_RESTRICT src, fmav<Cmplx<T>> &dst)
Martin Reinecke's avatar
Martin Reinecke committed
735
  {
Martin Reinecke's avatar
Martin Reinecke committed
736
  if (it.uniform_o())
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
    {
    auto ptr = &dst.vraw(it.oofs_uni(0,0));
    auto jstr = it.unistride_o();
    auto istr = it.stride_out();
    if (istr==1)
      for (size_t i=0; i<it.length_out(); ++i)
        for (size_t j=0; j<vlen; ++j)
          ptr[ptrdiff_t(j)*jstr + ptrdiff_t(i)].Set(src[i].r[j],src[i].i[j]);
    else if (jstr==1)
      for (size_t i=0; i<it.length_out(); ++i)
        for (size_t j=0; j<vlen; ++j)
          ptr[ptrdiff_t(j) + ptrdiff_t(i)*istr].Set(src[i].r[j],src[i].i[j]);
    else
      for (size_t i=0; i<it.length_out(); ++i)
        for (size_t j=0; j<vlen; ++j)
          dst.vraw(it.oofs_uni(j,i)).Set(src[i].r[j],src[i].i[j]);
    }
Martin Reinecke's avatar
Martin Reinecke committed
754
  else
755
756
    {
    auto ptr = dst.vdata();
757
    for (size_t i=0; i<it.length_out(); ++i)
Martin Reinecke's avatar
Martin Reinecke committed
758
759
      for (size_t j=0; j<vlen; ++j)
        ptr[it.oofs(j,i)].Set(src[i].r[j],src[i].i[j]);
760
    }
Martin Reinecke's avatar
Martin Reinecke committed
761
762
  }

763
template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
764
  const native_simd<T> *MRUTIL_RESTRICT src, fmav<T> &dst)
Martin Reinecke's avatar
Martin Reinecke committed
765
  {
Martin Reinecke's avatar
Martin Reinecke committed
766
  if (it.uniform_o())
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
    {
    auto ptr = &dst.vraw(it.oofs_uni(0,0));
    auto jstr = it.unistride_o();
    auto istr = it.stride_out();
    if (istr==1)
      for (size_t i=0; i<it.length_out(); ++i)
        for (size_t j=0; j<vlen; ++j)
          ptr[ptrdiff_t(j)*jstr + ptrdiff_t(i)] = src[i][j];
    else if (jstr==1)
      for (size_t i=0; i<it.length_out(); ++i)
        for (size_t j=0; j<vlen; ++j)
          ptr[ptrdiff_t(j) + ptrdiff_t(i)*istr] = src[i][j];
    else
      for (size_t i=0; i<it.length_out(); ++i)
        for (size_t j=0; j<vlen; ++j)
          dst.vraw(it.oofs_uni(j,i)) = src[i][j];
    }
Martin Reinecke's avatar
Martin Reinecke committed
784
  else
785
786
    {
    auto ptr=dst.vdata();
787
    for (size_t i=0; i<it.length_out(); ++i)
Martin Reinecke's avatar
Martin Reinecke committed
788
789
      for (size_t j=0; j<vlen; ++j)
        ptr[it.oofs(j,i)] = src[i][j];
790
    }
Martin Reinecke's avatar
Martin Reinecke committed
791
792
  }

793
template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
794
  const T *MRUTIL_RESTRICT src, fmav<T> &dst)
Martin Reinecke's avatar
Martin Reinecke committed
795
  {
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
796
  auto ptr=dst.vdata();
Martin Reinecke's avatar
Martin Reinecke committed
797
798
  if (src == &dst[it.oofs(0)]) return;  // in-place
  for (size_t i=0; i<it.length_out(); ++i)
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
799
    ptr[it.oofs(i)] = src[i];
Martin Reinecke's avatar
Martin Reinecke committed
800
801
  }

Martin Reinecke's avatar
Martin Reinecke committed
802
template <typename T> struct add_vec { using type = native_simd<T>; };
Martin Reinecke's avatar
Martin Reinecke committed
803
template <typename T> struct add_vec<Cmplx<T>>
Martin Reinecke's avatar
Martin Reinecke committed
804
  { using type = Cmplx<native_simd<T>>; };
Martin Reinecke's avatar
Martin Reinecke committed
805
806
807
template <typename T> using add_vec_t = typename add_vec<T>::type;

template<typename Tplan, typename T, typename T0, typename Exec>
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
808
MRUTIL_NOINLINE void general_nd(const fmav<T> &in, fmav<T> &out,
Martin Reinecke's avatar
Martin Reinecke committed
809
810
811
  const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec,
  const bool allow_inplace=true)
  {
Martin Reinecke's avatar
Martin Reinecke committed
812
  std::shared_ptr<Tplan> plan;
Martin Reinecke's avatar
Martin Reinecke committed
813
814
815
816
817
818
819
820

  for (size_t iax=0; iax<axes.size(); ++iax)
    {
    size_t len=in.shape(axes[iax]);
    if ((!plan) || (len!=plan->length()))
      plan = get_plan<Tplan>(len);

    execParallel(
Martin Reinecke's avatar
Martin Reinecke committed
821
      util::thread_count(nthreads, in, axes[iax], native_simd<T0>::size()),
Martin Reinecke's avatar
rework    
Martin Reinecke committed
822
      [&](Scheduler &sched) {
Martin Reinecke's avatar
Martin Reinecke committed
823
824
        constexpr auto vlen = native_simd<T0>::size();
        auto storage = alloc_tmp<T,T0>(in, len);
Martin Reinecke's avatar
Martin Reinecke committed
825
        const auto &tin(iax==0? in : out);
Martin Reinecke's avatar
rework    
Martin Reinecke committed
826
        multi_iter<vlen> it(tin, out, axes[iax], sched.num_threads(), sched.thread_num());
Martin Reinecke's avatar
Martin Reinecke committed
827
#ifndef MRUTIL_NO_SIMD
Martin Reinecke's avatar
Martin Reinecke committed
828
829
830
831
832
833
834
835
836
837
838
        if (vlen>1)
          while (it.remaining()>=vlen)
            {
            it.advance(vlen);
            auto tdatav = reinterpret_cast<add_vec_t<T> *>(storage.data());
            exec(it, tin, out, tdatav, *plan, fct);
            }
#endif
        while (it.remaining()>0)
          {
          it.advance(1);
Martin Reinecke's avatar
Martin Reinecke committed
839
          auto buf = allow_inplace && it.stride_out() == 1 ?
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
840
            &out.vraw(it.oofs(0)) : reinterpret_cast<T *>(storage.data());
Martin Reinecke's avatar
Martin Reinecke committed
841
842
843
844
845
846
847
848
849
850
851
          exec(it, tin, out, buf, *plan, fct);
          }
      });  // end of parallel region
    fct = T0(1); // factor has been applied, use 1 for remaining axes
    }
  }

struct ExecC2C
  {
  bool forward;

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
852
853
854
  template <typename T0, typename T, size_t vlen> void operator() (
    const multi_iter<vlen> &it, const fmav<Cmplx<T0>> &in,
    fmav<Cmplx<T0>> &out, T *buf, const pocketfft_c<T0> &plan, T0 fct) const
Martin Reinecke's avatar
Martin Reinecke committed
855
856
857
858
859
860
861
862
    {
    copy_input(it, in, buf);
    plan.exec(buf, fct, forward);
    copy_output(it, buf, out);
    }
  };

template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
863
  const native_simd<T> *MRUTIL_RESTRICT src, fmav<T> &dst)
Martin Reinecke's avatar
Martin Reinecke committed
864
  {
Martin Reinecke's avatar
Martin Reinecke committed
865
866
867
868
869
870
  if (it.uniform_o())
    {
    auto ptr = &dst.vraw(it.oofs_uni(0,0));
    auto jstr = it.unistride_o();
    auto istr = it.stride_out();
    if (istr==1)
Martin Reinecke's avatar
Martin Reinecke committed
871
      {
Martin Reinecke's avatar
Martin Reinecke committed
872
873
874
875
876
877
878
879
880
881
882
883
      for (size_t j=0; j<vlen; ++j)
        ptr[ptrdiff_t(j)*jstr] = src[0][j];
      size_t i=1, i1=1, i2=it.length_out()-1;
      for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2)
        for (size_t j=0; j<vlen; ++j)
          {
          ptr[ptrdiff_t(j)*jstr + ptrdiff_t(i1)] = src[i][j]+src[i+1][j];
          ptr[ptrdiff_t(j)*jstr + ptrdiff_t(i2)] = src[i][j]-src[i+1][j];
          }
      if (i<it.length_out())
        for (size_t j=0; j<vlen; ++j)
          ptr[ptrdiff_t(j)*jstr + ptrdiff_t(i1)] = src[i][j];
Martin Reinecke's avatar
Martin Reinecke committed
884
      }
Martin Reinecke's avatar
Martin Reinecke committed
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
    else if (jstr==1)
      {
      for (size_t j=0; j<vlen; ++j)
        ptr[ptrdiff_t(j)] = src[0][j];
      size_t i=1, i1=1, i2=it.length_out()-1;
      for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2)
        for (size_t j=0; j<vlen; ++j)
          {
          ptr[ptrdiff_t(j) + ptrdiff_t(i1)*istr] = src[i][j]+src[i+1][j];
          ptr[ptrdiff_t(j) + ptrdiff_t(i2)*istr] = src[i][j]-src[i+1][j];
          }
      if (i<it.length_out())
        for (size_t j=0; j<vlen; ++j)
          ptr[ptrdiff_t(j) + ptrdiff_t(i1)*istr] = src[i][j];
      }
    else
      {
      for (size_t j=0; j<vlen; ++j)
        ptr[ptrdiff_t(j)*jstr] = src[0][j];
      size_t i=1, i1=1, i2=it.length_out()-1;
      for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2)
        for (size_t j=0; j<vlen; ++j)
          {
          ptr[ptrdiff_t(j)*jstr + ptrdiff_t(i1)*istr] = src[i][j]+src[i+1][j];
          ptr[ptrdiff_t(j)*jstr + ptrdiff_t(i2)*istr] = src[i][j]-src[i+1][j];
          }
      if (i<it.length_out())
        for (size_t j=0; j<vlen; ++j)
          ptr[ptrdiff_t(j)*jstr + ptrdiff_t(i1)*istr] = src[i][j];
      }
    }
  else
    {
    auto ptr = dst.vdata();
Martin Reinecke's avatar
Martin Reinecke committed
919
    for (size_t j=0; j<vlen; ++j)
Martin Reinecke's avatar
Martin Reinecke committed
920
921
922
923
924
925
926
927
928
929
930
931
      ptr[it.oofs(j,0)] = src[0][j];
    size_t i=1, i1=1, i2=it.length_out()-1;
    for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2)
      for (size_t j=0; j<vlen; ++j)
        {
        ptr[it.oofs(j,i1)] = src[i][j]+src[i+1][j];
        ptr[it.oofs(j,i2)] = src[i][j]-src[i+1][j];
        }
    if (i<it.length_out())
      for (size_t j=0; j<vlen; ++j)
        ptr[it.oofs(j,i1)] = src[i][j];
    }
Martin Reinecke's avatar
Martin Reinecke committed
932
933
934
  }

template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
935
  const T *MRUTIL_RESTRICT src, fmav<T> &dst)
Martin Reinecke's avatar
Martin Reinecke committed
936
  {
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
937
938
  auto ptr = dst.vdata();
  ptr[it.oofs(0)] = src[0];
Martin Reinecke's avatar
Martin Reinecke committed
939
940
941
  size_t i=1, i1=1, i2=it.length_out()-1;
  for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2)
    {
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
942
943
    ptr[it.oofs(i1)] = src[i]+src[i+1];
    ptr[it.oofs(i2)] = src[i]-src[i+1];
Martin Reinecke's avatar
Martin Reinecke committed
944
945
    }
  if (i<it.length_out())
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
946
    ptr[it.oofs(i1)] = src[i];
Martin Reinecke's avatar
Martin Reinecke committed
947
948
949
950
951
  }

struct ExecHartley
  {
  template <typename T0, typename T, size_t vlen> void operator () (
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
952
    const multi_iter<vlen> &it, const fmav<T0> &in, fmav<T0> &out,
Martin Reinecke's avatar
Martin Reinecke committed
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
    T * buf, const pocketfft_r<T0> &plan, T0 fct) const
    {
    copy_input(it, in, buf);
    plan.exec(buf, fct, true);
    copy_hartley(it, buf, out);
    }
  };

struct ExecDcst
  {
  bool ortho;
  int type;
  bool cosine;

  template <typename T0, typename T, typename Tplan, size_t vlen>
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
968
969
  void operator () (const multi_iter<vlen> &it, const fmav<T0> &in,
    fmav <T0> &out, T * buf, const Tplan &plan, T0 fct) const
Martin Reinecke's avatar
Martin Reinecke committed
970
971
972
973
974
975
976
    {
    copy_input(it, in, buf);
    plan.exec(buf, fct, ortho, type, cosine);
    copy_output(it, buf, out);
    }
  };

Martin Reinecke's avatar
Martin Reinecke committed
977
template<typename T> MRUTIL_NOINLINE void general_r2c(
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
978
  const fmav<T> &in, fmav<Cmplx<T>> &out, size_t axis, bool forward, T fct,
Martin Reinecke's avatar
Martin Reinecke committed
979
980
981
982
983
  size_t nthreads)
  {
  auto plan = get_plan<pocketfft_r<T>>(in.shape(axis));
  size_t len=in.shape(axis);
  execParallel(
Martin Reinecke's avatar
Martin Reinecke committed
984
    util::thread_count(nthreads, in, axis, native_simd<T>::size()),
Martin Reinecke's avatar
rework    
Martin Reinecke committed
985
    [&](Scheduler &sched) {
Martin Reinecke's avatar
Martin Reinecke committed
986
987
    constexpr auto vlen = native_simd<T>::size();
    auto storage = alloc_tmp<T,T>(in, len);
Martin Reinecke's avatar
rework    
Martin Reinecke committed
988
    multi_iter<vlen> it(in, out, axis, sched.num_threads(), sched.thread_num());
Martin Reinecke's avatar
Martin Reinecke committed
989
#ifndef MRUTIL_NO_SIMD
Martin Reinecke's avatar
Martin Reinecke committed
990
991
992
993
    if (vlen>1)
      while (it.remaining()>=vlen)
        {
        it.advance(vlen);
Martin Reinecke's avatar
Martin Reinecke committed
994
        auto tdatav = reinterpret_cast<native_simd<T> *>(storage.data());
Martin Reinecke's avatar
Martin Reinecke committed
995
996
        copy_input(it, in, tdatav);
        plan->exec(tdatav, fct, true);
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
997
        auto vout = out.vdata();
Martin Reinecke's avatar
Martin Reinecke committed
998
        for (size_t j=0; j<vlen; ++j)
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
999
          vout[it.oofs(j,0)].Set(tdatav[0][j]);
Martin Reinecke's avatar
Martin Reinecke committed
1000
1001
1002
1003
        size_t i=1, ii=1;
        if (forward)
          for (; i<len-1; i+=2, ++ii)
            for (size_t j=0; j<vlen; ++j)
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
1004
              vout[it.oofs(j,ii)].Set(tdatav[i][j], tdatav[i+1][j]);
Martin Reinecke's avatar
Martin Reinecke committed
1005
1006
1007
        else
          for (; i<len-1; i+=2, ++ii)
            for (size_t j=0; j<vlen; ++j)
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
1008
              vout[it.oofs(j,ii)].Set(tdatav[i][j], -tdatav[i+1][j]);
Martin Reinecke's avatar
Martin Reinecke committed
1009
1010
        if (i<len)
          for (size_t j=0; j<vlen; ++j)
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
1011
            vout[it.oofs(j,ii)].Set(tdatav[i][j]);
Martin Reinecke's avatar
Martin Reinecke committed
1012
1013
1014
1015
1016
1017
1018
1019
        }
#endif
    while (it.remaining()>0)
      {
      it.advance(1);
      auto tdata = reinterpret_cast<T *>(storage.data());
      copy_input(it, in, tdata);
      plan->exec(tdata, fct, true);
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
1020
1021
      auto vout = out.vdata();
      vout[it.oofs(0)].Set(tdata[0]);
Martin Reinecke's avatar
Martin Reinecke committed
1022
1023
1024
      size_t i=1, ii=1;
      if (forward)
        for (; i<len-1; i+=2, ++ii)
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
1025
          vout[it.oofs(ii)].Set(tdata[i], tdata[i+1]);
Martin Reinecke's avatar
Martin Reinecke committed
1026
1027
      else
        for (; i<len-1; i+=2, ++ii)
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
1028
          vout[it.oofs(ii)].Set(tdata[i], -tdata[i+1]);
Martin Reinecke's avatar
Martin Reinecke committed
1029
      if (i<len)
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
1030
        vout[it.oofs(ii)].Set(tdata[i]);
Martin Reinecke's avatar
Martin Reinecke committed
1031
1032
1033
      }
    });  // end of parallel region
  }
Martin Reinecke's avatar
Martin Reinecke committed
1034
/// Complex-to-real FFT of \a in along \a axis, written to \a out.
/// - Complex input elements are packed into a halfcomplex buffer
///   [r0, r1, i1, r2, i2, ...] of length len = out.shape(axis); imaginary
///   parts are negated when \a forward is true.
/// - The buffer is transformed by the real plan in the halfcomplex->real
///   direction, scaled by \a fct, and scattered into \a out.
/// - Lines are processed vlen at a time through SIMD buffers (unless
///   MRUTIL_NO_SIMD is defined), then one at a time.
template<typename T> MRUTIL_NOINLINE void general_c2r(
  const fmav<Cmplx<T>> &in, fmav<T> &out, size_t axis, bool forward, T fct,
  size_t nthreads)
  {
  const size_t len = out.shape(axis);
  auto plan = get_plan<pocketfft_r<T>>(len);
  auto worker = [&](Scheduler &sched)
    {
    constexpr auto vlen = native_simd<T>::size();
    auto storage = alloc_tmp<T,T>(out, len);
    multi_iter<vlen> it(in, out, axis, sched.num_threads(), sched.thread_num());
#ifndef MRUTIL_NO_SIMD
    if (vlen>1)
      {
      auto vbuf = reinterpret_cast<native_simd<T> *>(storage.data());
      while (it.remaining()>=vlen)
        {
        it.advance(vlen);
        for (size_t lane=0; lane<vlen; ++lane)
          vbuf[0][lane] = in[it.iofs(lane,0)].r;
        size_t k=1, m=1;
        for (; k<len-1; k+=2, ++m)
          for (size_t lane=0; lane<vlen; ++lane)
            {
            const auto &c = in[it.iofs(lane,m)];
            vbuf[k  ][lane] = c.r;
            vbuf[k+1][lane] = forward ? -c.i : c.i;
            }
        if (k<len)  // trailing real-only entry
          for (size_t lane=0; lane<vlen; ++lane)
            vbuf[k][lane] = in[it.iofs(lane,m)].r;
        plan->exec(vbuf, fct, false);
        copy_output(it, vbuf, out);
        }
      }
#endif
    auto sbuf = reinterpret_cast<T *>(storage.data());
    while (it.remaining()>0)
      {
      it.advance(1);
      sbuf[0] = in[it.iofs(0)].r;
      size_t k=1, m=1;
      for (; k<len-1; k+=2, ++m)
        {
        const auto &c = in[it.iofs(m)];
        sbuf[k  ] = c.r;
        sbuf[k+1] = forward ? -c.i : c.i;
        }
      if (k<len)  // trailing real-only entry
        sbuf[k] = in[it.iofs(m)].r;
      plan->exec(sbuf, fct, false);
      copy_output(it, sbuf, out);
      }
    };
  execParallel(
    util::thread_count(nthreads, in, axis, native_simd<T>::size()), worker);
  }

struct ExecR2R
  {
  bool r2c, forward;

  template <typename T0, typename T, size_t vlen> void operator () (
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1111
    const multi_iter<vlen> &it, const fmav<T0> &in, fmav<T0> &out, T * buf,
Martin Reinecke's avatar
Martin Reinecke committed
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
    const pocketfft_r<T0> &plan, T0 fct) const
    {
    copy_input(it, in, buf);
    if ((!r2c) && forward)
      for (size_t i=2; i<it.length_out(); i+=2)
        buf[i] = -buf[i];
    plan.exec(buf, fct, forward);
    if (r2c && (!forward))
      for (size_t i=2; i<it.length_out(); i+=2)
        buf[i] = -buf[i];
    copy_output(it, buf, out);
    }
  };

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1126
1127
template<typename T> void c2c(const fmav<std::complex<T>> &in,
  fmav<std::complex<T>> &out, const shape_t &axes, bool forward,
Martin Reinecke's avatar
Martin Reinecke committed
1128
  T fct, size_t nthreads=1)
Martin Reinecke's avatar
Martin Reinecke committed
1129
  {
Martin Reinecke's avatar
Martin Reinecke committed
1130
1131
  util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
  if (in.size()==0) return;
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1132
  fmav<Cmplx<T>> in2(reinterpret_cast<const Cmplx<T> *>(in.data()), in);
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
1133
  fmav<Cmplx<T>> out2(reinterpret_cast<Cmplx<T> *>(out.vdata()), out, out.writable());
Martin Reinecke's avatar
Martin Reinecke committed
1134
  general_nd<pocketfft_c<T>>(in2, out2, axes, fct, nthreads, ExecC2C{forward});
Martin Reinecke's avatar
Martin Reinecke committed
1135
1136
  }

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1137
template<typename T> void dct(const fmav<T> &in, fmav<T> &out,
Martin Reinecke's avatar
Martin Reinecke committed
1138
  const shape_t &axes, int type, T fct, bool ortho, size_t nthreads=1)
Martin Reinecke's avatar
Martin Reinecke committed
1139
  {
Martin Reinecke's avatar
Martin Reinecke committed
1140
  if ((type<1) || (type>4)) throw std::invalid_argument("invalid DCT type");
Martin Reinecke's avatar
Martin Reinecke committed
1141
1142
  util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
  if (in.size()==0) return;
Martin Reinecke's avatar
Martin Reinecke committed
1143
1144
  const ExecDcst exec{ortho, type, true};
  if (type==1)
Martin Reinecke's avatar
Martin Reinecke committed
1145
    general_nd<T_dct1<T>>(in, out, axes, fct, nthreads, exec);
Martin Reinecke's avatar
Martin Reinecke committed
1146
  else if (type==4)
Martin Reinecke's avatar
Martin Reinecke committed
1147
    general_nd<T_dcst4<T>>(in, out, axes, fct, nthreads, exec);
Martin Reinecke's avatar
Martin Reinecke committed
1148
  else
Martin Reinecke's avatar
Martin Reinecke committed
1149
    general_nd<T_dcst23<T>>(in, out, axes, fct, nthreads, exec);
Martin Reinecke's avatar
Martin Reinecke committed
1150
1151
  }

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1152
template<typename T> void dst(const fmav<T> &in, fmav<T> &out,
Martin Reinecke's avatar
Martin Reinecke committed
1153
  const shape_t &axes, int type, T fct, bool ortho, size_t nthreads=1)
Martin Reinecke's avatar
Martin Reinecke committed
1154
  {
Martin Reinecke's avatar
Martin Reinecke committed
1155
  if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type");
Martin Reinecke's avatar
Martin Reinecke committed
1156
  util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
Martin Reinecke's avatar
Martin Reinecke committed
1157
1158
  const ExecDcst exec{ortho, type, false};
  if (type==1)
Martin Reinecke's avatar
Martin Reinecke committed
1159
    general_nd<T_dst1<T>>(in, out, axes, fct, nthreads, exec);
Martin Reinecke's avatar
Martin Reinecke committed
1160
  else if (type==4)
Martin Reinecke's avatar
Martin Reinecke committed
1161
    general_nd<T_dcst4<T>>(in, out, axes, fct, nthreads, exec);
Martin Reinecke's avatar
Martin Reinecke committed
1162
  else
Martin Reinecke's avatar
Martin Reinecke committed
1163
    general_nd<T_dcst23<T>>(in, out, axes, fct, nthreads, exec);
Martin Reinecke's avatar
Martin Reinecke committed
1164
1165
  }

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1166
1167
template<typename T> void r2c(const fmav<T> &in,
  fmav<std::complex<T>> &out, size_t axis, bool forward, T fct,
Martin Reinecke's avatar
Martin Reinecke committed
1168
1169
  size_t nthreads=1)
  {
Martin Reinecke's avatar
Martin Reinecke committed
1170
1171
  util::sanity_check_cr(out, in, axis);
  if (in.size()==0) return;
Martin Reinecke's avatar
stage 2    
Martin Reinecke committed
1172
  fmav<Cmplx<T>> out2(reinterpret_cast<Cmplx<T> *>(out.vdata()), out, out.writable());
Martin Reinecke's avatar
Martin Reinecke committed
1173
  general_r2c(in, out2, axis, forward, fct, nthreads);
Martin Reinecke's avatar
Martin Reinecke committed
1174
1175
  }

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1176
1177
template<typename T> void r2c(const fmav<T> &in,
  fmav<std::complex<T>> &out, const shape_t &axes,
Martin Reinecke's avatar
Martin Reinecke committed
1178
  bool forward, T fct, size_t nthreads=1)
Martin Reinecke's avatar
Martin Reinecke committed
1179
  {
Martin Reinecke's avatar
Martin Reinecke committed
1180
1181
1182
  util::sanity_check_cr(out, in, axes);
  if (in.size()==0) return;
  r2c(in, out, axes.back(), forward, fct, nthreads);
Martin Reinecke's avatar
Martin Reinecke committed
1183
1184
1185
  if (axes.size()==1) return;

  auto newaxes = shape_t{axes.begin(), --axes.end()};
Martin Reinecke's avatar
Martin Reinecke committed
1186
  c2c(out, out, newaxes, forward, T(1), nthreads);
Martin Reinecke's avatar
Martin Reinecke committed
1187
1188
  }

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1189
1190
template<typename T> void c2r(const fmav<std::complex<T>> &in,
  fmav<T> &out,  size_t axis, bool forward, T fct, size_t nthreads=1)
Martin Reinecke's avatar
Martin Reinecke committed
1191
  {
Martin Reinecke's avatar
Martin Reinecke committed
1192
1193
  util::sanity_check_cr(in, out, axis);
  if (in.size()==0) return;
Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1194
  fmav<Cmplx<T>> in2(reinterpret_cast<const Cmplx<T> *>(in.data()), in);
Martin Reinecke's avatar
Martin Reinecke committed
1195
  general_c2r(in2, out, axis, forward, fct, nthreads);
Martin Reinecke's avatar
Martin Reinecke committed
1196
1197
  }

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1198
1199
template<typename T> void c2r(const fmav<std::complex<T>> &in,
  fmav<T> &out, const shape_t &axes, bool forward, T fct,
Martin Reinecke's avatar
Martin Reinecke committed
1200
1201
1202
  size_t nthreads=1)
  {
  if (axes.size()==1)
Martin Reinecke's avatar
Martin Reinecke committed
1203
1204
1205
    return c2r(in, out, axes[0], forward, fct, nthreads);
  util::sanity_check_cr(in, out, axes);
  if (in.size()==0) return;
1206
  fmav<std::complex<T>> atmp(in.shape());
1207
  auto newaxes = shape_t{axes.begin(), --axes.end()};
Martin Reinecke's avatar
Martin Reinecke committed
1208
1209
  c2c(in, atmp, newaxes, forward, T(1), nthreads);
  c2r(atmp, out, axes.back(), forward, fct, nthreads);
Martin Reinecke's avatar
Martin Reinecke committed
1210
1211
  }

Martin Reinecke's avatar
stage 1    
Martin Reinecke committed
1212
1213
template<typename T> void r2r_fftpack(const fmav<T> &in,
  fmav<T> &out, const shape_t &axes, bool real2hermitian, bool forward,
Martin Reinecke's avatar
Martin Reinecke committed
1214
  T fct, size_t nthreads=1)
Martin Reinecke's avatar
Martin Reinecke committed
1215
  {