fft.h 37.3 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1 2 3
/*
This file is part of pocketfft.

Martin Reinecke's avatar
Martin Reinecke committed
4
Copyright (C) 2010-2020 Max-Planck-Society
Martin Reinecke's avatar
Martin Reinecke committed
5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
Copyright (C) 2019 Peter Bell

For the odd-sized DCT-IV transforms:
  Copyright (C) 2003, 2007-14 Matteo Frigo
  Copyright (C) 2003, 2007-14 Massachusetts Institute of Technology

Authors: Martin Reinecke, Peter Bell

All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this
  list of conditions and the following disclaimer in the documentation and/or
  other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its contributors may
  be used to endorse or promote products derived from this software without
  specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

#ifndef MRUTIL_FFT_H
#define MRUTIL_FFT_H
41

42
#include "mr_util/math/fft1d.h"
Martin Reinecke's avatar
Martin Reinecke committed
43

Martin Reinecke's avatar
Martin Reinecke committed
44 45 46 47 48 49
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdlib>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <vector>
51 52 53
#include "mr_util/infra/threading.h"
#include "mr_util/infra/simd.h"
#include "mr_util/infra/mav.h"
Martin Reinecke's avatar
Martin Reinecke committed
54 55 56
#ifndef MRUTIL_NO_THREADING
#include <mutex>
#endif
Martin Reinecke's avatar
Martin Reinecke committed
57 58 59 60 61

namespace mr {

namespace detail_fft {

Martin Reinecke's avatar
Martin Reinecke committed
62
// Extent/index vector and stride vector types, taken over from fmav_info.
using shape_t = fmav_info::shape_t;
using stride_t = fmav_info::stride_t;
Martin Reinecke's avatar
Martin Reinecke committed
64

Martin Reinecke's avatar
Martin Reinecke committed
65 66 67 68 69
// Named boolean values selecting the transform direction.
constexpr bool FORWARD = true;
constexpr bool BACKWARD = !FORWARD;

struct util // hack to avoid duplicate symbols
  {
Martin Reinecke's avatar
Martin Reinecke committed
70
  static void sanity_check_axes(size_t ndim, const shape_t &axes)
Martin Reinecke's avatar
Martin Reinecke committed
71 72
    {
    shape_t tmp(ndim,0);
Martin Reinecke's avatar
Martin Reinecke committed
73
    if (axes.empty()) throw std::invalid_argument("no axes specified");
Martin Reinecke's avatar
Martin Reinecke committed
74 75
    for (auto ax : axes)
      {
Martin Reinecke's avatar
Martin Reinecke committed
76 77
      if (ax>=ndim) throw std::invalid_argument("bad axis number");
      if (++tmp[ax]>1) throw std::invalid_argument("axis specified repeatedly");
Martin Reinecke's avatar
Martin Reinecke committed
78 79 80
      }
    }

Martin Reinecke's avatar
Martin Reinecke committed
81 82
  static MRUTIL_NOINLINE void sanity_check_onetype(const fmav_info &a1,
    const fmav_info &a2, bool inplace, const shape_t &axes)
Martin Reinecke's avatar
Martin Reinecke committed
83
    {
Martin Reinecke's avatar
Martin Reinecke committed
84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104
    sanity_check_axes(a1.ndim(), axes);
    MR_assert(a1.conformable(a2), "array sizes are not conformable");
    if (inplace) MR_assert(a1.stride()==a2.stride(), "stride mismatch");
    }
  static MRUTIL_NOINLINE void sanity_check_cr(const fmav_info &ac,
    const fmav_info &ar, const shape_t &axes)
    {
    sanity_check_axes(ac.ndim(), axes);
    MR_assert(ac.ndim()==ar.ndim(), "dimension mismatch");
    for (size_t i=0; i<ac.ndim(); ++i)
      MR_assert(ac.shape(i)== (i==axes.back()) ? (ar.shape(i)/2+1) : ar.shape(i),
        "axis length mismatch");
    }
  static MRUTIL_NOINLINE void sanity_check_cr(const fmav_info &ac,
    const fmav_info &ar, const size_t axis)
    {
    if (axis>=ac.ndim()) throw std::invalid_argument("bad axis number");
    MR_assert(ac.ndim()==ar.ndim(), "dimension mismatch");
    for (size_t i=0; i<ac.ndim(); ++i)
      MR_assert(ac.shape(i)== (i==axis) ? (ar.shape(i)/2+1) : ar.shape(i),
        "axis length mismatch");
Martin Reinecke's avatar
Martin Reinecke committed
105 106 107
    }

#ifdef MRUTIL_NO_THREADING
Martin Reinecke's avatar
Martin Reinecke committed
108
  static size_t thread_count (size_t /*nthreads*/, const fmav_info &/*info*/,
Martin Reinecke's avatar
Martin Reinecke committed
109 110 111
    size_t /*axis*/, size_t /*vlen*/)
    { return 1; }
#else
Martin Reinecke's avatar
Martin Reinecke committed
112
  static size_t thread_count (size_t nthreads, const fmav_info &info,
Martin Reinecke's avatar
Martin Reinecke committed
113 114 115
    size_t axis, size_t vlen)
    {
    if (nthreads==1) return 1;
Martin Reinecke's avatar
Martin Reinecke committed
116 117 118
    size_t size = info.size();
    size_t parallel = size / (info.shape(axis) * vlen);
    if (info.shape(axis) < 1000)
Martin Reinecke's avatar
Martin Reinecke committed
119
      parallel /= 4;
Martin Reinecke's avatar
Martin Reinecke committed
120
    size_t max_threads = (nthreads==0) ? mr::max_threads() : nthreads;
121
    return std::max(size_t(1), std::min(parallel, max_threads));
Martin Reinecke's avatar
Martin Reinecke committed
122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
    }
#endif
  };


//
// sine/cosine transforms
//

template<typename T0> class T_dct1
  {
  private:
    pocketfft_r<T0> fftplan;

  public:
Martin Reinecke's avatar
Martin Reinecke committed
137
    MRUTIL_NOINLINE T_dct1(size_t length)
Martin Reinecke's avatar
Martin Reinecke committed
138 139
      : fftplan(2*(length-1)) {}

Martin Reinecke's avatar
Martin Reinecke committed
140
    template<typename T> MRUTIL_NOINLINE void exec(T c[], T0 fct, bool ortho,
Martin Reinecke's avatar
Martin Reinecke committed
141 142 143 144 145 146
      int /*type*/, bool /*cosine*/) const
      {
      constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
      size_t N=fftplan.length(), n=N/2+1;
      if (ortho)
        { c[0]*=sqrt2; c[n-1]*=sqrt2; }
147
      aligned_array<T> tmp(N);
Martin Reinecke's avatar
Martin Reinecke committed
148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167
      tmp[0] = c[0];
      for (size_t i=1; i<n; ++i)
        tmp[i] = tmp[N-i] = c[i];
      fftplan.exec(tmp.data(), fct, true);
      c[0] = tmp[0];
      for (size_t i=1; i<n; ++i)
        c[i] = tmp[2*i-1];
      if (ortho)
        { c[0]*=sqrt2*T0(0.5); c[n-1]*=sqrt2*T0(0.5); }
      }

    size_t length() const { return fftplan.length()/2+1; }
  };

template<typename T0> class T_dst1
  {
  private:
    pocketfft_r<T0> fftplan;

  public:
Martin Reinecke's avatar
Martin Reinecke committed
168
    MRUTIL_NOINLINE T_dst1(size_t length)
Martin Reinecke's avatar
Martin Reinecke committed
169 170
      : fftplan(2*(length+1)) {}

Martin Reinecke's avatar
Martin Reinecke committed
171
    template<typename T> MRUTIL_NOINLINE void exec(T c[], T0 fct,
Martin Reinecke's avatar
Martin Reinecke committed
172 173 174
      bool /*ortho*/, int /*type*/, bool /*cosine*/) const
      {
      size_t N=fftplan.length(), n=N/2-1;
175
      aligned_array<T> tmp(N);
Martin Reinecke's avatar
Martin Reinecke committed
176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
      tmp[0] = tmp[n+1] = c[0]*0;
      for (size_t i=0; i<n; ++i)
        { tmp[i+1]=c[i]; tmp[N-1-i]=-c[i]; }
      fftplan.exec(tmp.data(), fct, true);
      for (size_t i=0; i<n; ++i)
        c[i] = -tmp[2*i+2];
      }

    size_t length() const { return fftplan.length()/2-1; }
  };

// DCT/DST of type II (type==2) and type III (any other `type` value) of a
// fixed length, computed in place through a real FFT of the same length.
// The sine variants (cosine==false) reuse the cosine algorithm: they negate
// the odd-indexed entries and reverse the array. For type II the negation
// happens before and the reversal after the transform; for type III it is
// the other way round.
template<typename T0> class T_dcst23
  {
  private:
    pocketfft_r<T0> fftplan;
Martin Reinecke's avatar
Martin Reinecke committed
191
    // twiddle[i] = real part of root i+1 from a UnityRoots table of size
    // 4*length (presumably cos(pi*(i+1)/(2*length)) -- TODO confirm)
    std::vector<T0> twiddle;
Martin Reinecke's avatar
Martin Reinecke committed
192 193

  public:
Martin Reinecke's avatar
Martin Reinecke committed
194
    MRUTIL_NOINLINE T_dcst23(size_t length)
Martin Reinecke's avatar
Martin Reinecke committed
195 196 197 198 199 200 201
      : fftplan(length), twiddle(length)
      {
      UnityRoots<T0,Cmplx<T0>> tw(4*length);
      for (size_t i=0; i<length; ++i)
        twiddle[i] = tw[i+1].r;
      }

Martin Reinecke's avatar
Martin Reinecke committed
202
    // In-place transform of c[0 .. length()-1], scaled by fct.
    // With `ortho`, c[0] is scaled by sqrt(2)/2 after a type II transform
    // and by sqrt(2) before a type III transform.
    template<typename T> MRUTIL_NOINLINE void exec(T c[], T0 fct, bool ortho,
Martin Reinecke's avatar
Martin Reinecke committed
203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227
      int type, bool cosine) const
      {
      constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
      size_t N=length();
      size_t NS2 = (N+1)/2;
      if (type==2)
        {
        // DST-II: negate odd-indexed inputs, then proceed as for DCT-II
        if (!cosine)
          for (size_t k=1; k<N; k+=2)
            c[k] = -c[k];
        c[0] *= 2;
        if ((N&1)==0) c[N-1]*=2;
        // MPINPLACE on neighbouring pairs prepares the half-complex input
        // for the backward real FFT below
        for (size_t k=1; k<N-1; k+=2)
          MPINPLACE(c[k+1], c[k]);
        fftplan.exec(c, fct, false);
        // post-twiddle: combine mirrored pairs (k, N-k)
        for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
          {
          T t1 = twiddle[k-1]*c[kc]+twiddle[kc-1]*c[k];
          T t2 = twiddle[k-1]*c[k]-twiddle[kc-1]*c[kc];
          c[k] = T0(0.5)*(t1+t2); c[kc]=T0(0.5)*(t1-t2);
          }
        if ((N&1)==0)
          c[NS2] *= twiddle[NS2-1];
        // DST-II: reverse the output order
        if (!cosine)
          for (size_t k=0, kc=N-1; k<kc; ++k, --kc)
Martin Reinecke's avatar
Martin Reinecke committed
228
            std::swap(c[k], c[kc]);
Martin Reinecke's avatar
Martin Reinecke committed
229 230 231 232 233 234 235
        if (ortho) c[0]*=sqrt2*T0(0.5);
        }
      else
        {
        if (ortho) c[0]*=sqrt2;
        // DST-III: reverse the input order, then proceed as for DCT-III
        if (!cosine)
          for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
Martin Reinecke's avatar
Martin Reinecke committed
236
            std::swap(c[k], c[kc]);
Martin Reinecke's avatar
Martin Reinecke committed
237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260
        // pre-twiddle: combine mirrored pairs (k, N-k)
        for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
          {
          T t1=c[k]+c[kc], t2=c[k]-c[kc];
          c[k] = twiddle[k-1]*t2+twiddle[kc-1]*t1;
          c[kc]= twiddle[k-1]*t1-twiddle[kc-1]*t2;
          }
        if ((N&1)==0)
          c[NS2] *= 2*twiddle[NS2-1];
        fftplan.exec(c, fct, true);
        // MPINPLACE on neighbouring pairs unpacks the half-complex output
        for (size_t k=1; k<N-1; k+=2)
          MPINPLACE(c[k], c[k+1]);
        // DST-III: negate odd-indexed outputs
        if (!cosine)
          for (size_t k=1; k<N; k+=2)
            c[k] = -c[k];
        }
      }

    size_t length() const { return fftplan.length(); }
  };

// DCT-IV / DST-IV of a fixed length N, computed in place.
// Odd N:  a real FFT of length N is used, with index/sign bookkeeping derived
//         from FFTW3's apply_re11().
// Even N: a complex FFT of length N/2 is used, with pre- and
//         post-multiplication by the twiddles in C2.
// The DST variant reverses the input and negates the odd-indexed outputs.
template<typename T0> class T_dcst4
  {
  private:
    size_t N;
Martin Reinecke's avatar
Martin Reinecke committed
261 262
    // exactly one of these is allocated: fft for even N, rfft for odd N
    std::unique_ptr<pocketfft_c<T0>> fft;
    std::unique_ptr<pocketfft_r<T0>> rfft;
263
    // even N only: C2[i] = conj(root 8*i+1 of a 16*N UnityRoots table)
    aligned_array<Cmplx<T0>> C2;
Martin Reinecke's avatar
Martin Reinecke committed
264 265

  public:
Martin Reinecke's avatar
Martin Reinecke committed
266
    MRUTIL_NOINLINE T_dcst4(size_t length)
Martin Reinecke's avatar
Martin Reinecke committed
267 268 269 270 271 272 273 274 275 276 277 278 279
      : N(length),
        fft((N&1) ? nullptr : new pocketfft_c<T0>(N/2)),
        rfft((N&1)? new pocketfft_r<T0>(N) : nullptr),
        C2((N&1) ? 0 : N/2)
      {
      if ((N&1)==0)
        {
        UnityRoots<T0,Cmplx<T0>> tw(16*N);
        for (size_t i=0; i<N/2; ++i)
          C2[i] = conj(tw[8*i+1]);
        }
      }

Martin Reinecke's avatar
Martin Reinecke committed
280
    // In-place transform of c[0 .. N-1], scaled by fct; `ortho` and `type`
    // are ignored. cosine==false selects the DST-IV.
    template<typename T> MRUTIL_NOINLINE void exec(T c[], T0 fct,
Martin Reinecke's avatar
Martin Reinecke committed
281 282 283 284 285
      bool /*ortho*/, int /*type*/, bool cosine) const
      {
      size_t n2 = N/2;
      // DST-IV: reverse the input order
      if (!cosine)
        for (size_t k=0, kc=N-1; k<n2; ++k, --kc)
Martin Reinecke's avatar
Martin Reinecke committed
286
          std::swap(c[k], c[kc]);
Martin Reinecke's avatar
Martin Reinecke committed
287 288 289 290 291 292
      if (N&1)
        {
        // The following code is derived from the FFTW3 function apply_re11()
        // and is released under the 3-clause BSD license with friendly
        // permission of Matteo Frigo and Steven G. Johnson.

293
        aligned_array<T> y(N);
Martin Reinecke's avatar
Martin Reinecke committed
294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335
        // gather c into y: index m steps by 4 starting at n2, folded back
        // into [0,N) with the sign of the quarter-period it falls into
        {
        size_t i=0, m=n2;
        for (; m<N; ++i, m+=4)
          y[i] = c[m];
        for (; m<2*N; ++i, m+=4)
          y[i] = -c[2*N-m-1];
        for (; m<3*N; ++i, m+=4)
          y[i] = -c[m-2*N];
        for (; m<4*N; ++i, m+=4)
          y[i] = c[4*N-m-1];
        for (; i<N; ++i, m+=4)
          y[i] = c[m-4*N];
        }
        rfft->exec(y.data(), fct, true);
        // scatter the half-complex result back into c, weighting each term
        // by +-sqrt(2) depending on bit 1 of its index (see SGN)
        {
        auto SGN = [](size_t i)
           {
           constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
           return (i&2) ? -sqrt2 : sqrt2;
           };
        c[n2] = y[0]*SGN(n2+1);
        size_t i=0, i1=1, k=1;
        for (; k<n2; ++i, ++i1, k+=2)
          {
          c[i    ] = y[2*k-1]*SGN(i1)     + y[2*k  ]*SGN(i);
          c[N -i1] = y[2*k-1]*SGN(N -i)   - y[2*k  ]*SGN(N -i1);
          c[n2-i1] = y[2*k+1]*SGN(n2-i)   - y[2*k+2]*SGN(n2-i1);
          c[n2+i1] = y[2*k+1]*SGN(n2+i+2) + y[2*k+2]*SGN(n2+i1);
          }
        // leftover pair when n2 is odd
        if (k == n2)
          {
          c[i   ] = y[2*k-1]*SGN(i+1) + y[2*k]*SGN(i);
          c[N-i1] = y[2*k-1]*SGN(i+2) + y[2*k]*SGN(i1);
          }
        }

        // FFTW-derived code ends here
        }
      else
        {
        // even length algorithm from
        // https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/
336
        aligned_array<Cmplx<T>> y(n2);
Martin Reinecke's avatar
Martin Reinecke committed
337 338 339 340 341 342 343 344
        // pack even-indexed entries with mirrored odd-indexed entries into
        // n2 complex values, pre-multiplied by C2
        for(size_t i=0; i<n2; ++i)
          {
          y[i].Set(c[2*i],c[N-1-2*i]);
          y[i] *= C2[i];
          }
        fft->exec(y.data(), fct, true);
        // post-multiply by C2: the real part of (y[i]*C2[i]) times 2 gives
        // the even outputs; the imaginary part of (y[ic]*C2[ic]) times -2 gives
        // the odd outputs in reverse order
        for(size_t i=0, ic=n2-1; i<n2; ++i, --ic)
          {
Martin Reinecke's avatar
tweaks  
Martin Reinecke committed
345 346
          c[2*i  ] = T0( 2)*(y[i ].r*C2[i ].r-y[i ].i*C2[i ].i);
          c[2*i+1] = T0(-2)*(y[ic].i*C2[ic].r+y[ic].r*C2[ic].i);
Martin Reinecke's avatar
Martin Reinecke committed
347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364
          }
        }
      // DST-IV: negate odd-indexed outputs
      if (!cosine)
        for (size_t k=1; k<N; k+=2)
          c[k] = -c[k];
      }

    size_t length() const { return N; }
  };


//
// multi-D infrastructure
//

template<size_t N> class multi_iter
  {
  private:
Martin Reinecke's avatar
Martin Reinecke committed
365 366 367 368 369
    shape_t shp, pos;
    stride_t str_i, str_o;
    size_t cshp_i, cshp_o, rem;
    ptrdiff_t cstr_i, cstr_o, sstr_i, sstr_o, p_ii, p_i[N], p_oi, p_o[N];
    bool uni_i, uni_o;
Martin Reinecke's avatar
Martin Reinecke committed
370 371 372

    void advance_i()
      {
Martin Reinecke's avatar
Martin Reinecke committed
373
      for (size_t i=0; i<pos.size(); ++i)
Martin Reinecke's avatar
Martin Reinecke committed
374
        {
Martin Reinecke's avatar
Martin Reinecke committed
375 376 377
        p_ii += str_i[i];
        p_oi += str_o[i];
        if (++pos[i] < shp[i])
Martin Reinecke's avatar
Martin Reinecke committed
378 379
          return;
        pos[i] = 0;
Martin Reinecke's avatar
Martin Reinecke committed
380 381
        p_ii -= ptrdiff_t(shp[i])*str_i[i];
        p_oi -= ptrdiff_t(shp[i])*str_o[i];
Martin Reinecke's avatar
Martin Reinecke committed
382 383 384 385
        }
      }

  public:
Martin Reinecke's avatar
Martin Reinecke committed
386 387 388
    multi_iter(const fmav_info &iarr, const fmav_info &oarr, size_t idim,
      size_t nshares, size_t myshare)
      : rem(iarr.size()/iarr.shape(idim)), sstr_i(0), sstr_o(0), p_ii(0), p_oi(0)
Martin Reinecke's avatar
Martin Reinecke committed
389
      {
Martin Reinecke's avatar
Martin Reinecke committed
390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436
      MR_assert(oarr.ndim()==iarr.ndim(), "dimension mismatch");
      MR_assert(iarr.ndim()>=1, "not enough dimensions");
      // Sort the extraneous dimensions in order of ascending output stride;
      // this should improve overall cache re-use and avoid clashes between
      // threads as much as possible.
      shape_t idx(iarr.ndim());
      std::iota(idx.begin(), idx.end(), 0);
      sort(idx.begin(), idx.end(),
        [&oarr](size_t i1, size_t i2) {return oarr.stride(i1) < oarr.stride(i2);});
      for (auto i: idx)
        if (i!=idim)
          {
          pos.push_back(0);
          MR_assert(iarr.shape(i)==oarr.shape(i), "shape mismatch");
          shp.push_back(iarr.shape(i));
          str_i.push_back(iarr.stride(i));
          str_o.push_back(oarr.stride(i));
          }
      MR_assert(idim<iarr.ndim(), "bad active dimension");
      cstr_i = iarr.stride(idim);
      cstr_o = oarr.stride(idim);
      cshp_i = iarr.shape(idim);
      cshp_o = oarr.shape(idim);

// collapse unneeded dimensions
      bool done = false;
      while(!done)
        {
        done=true;
        for (size_t i=1; i<shp.size(); ++i)
          if ((str_i[i] == str_i[i-1]*ptrdiff_t(shp[i-1]))
           && (str_o[i] == str_o[i-1]*ptrdiff_t(shp[i-1])))
            {
            shp[i-1] *= shp[i];
            str_i.erase(str_i.begin()+ptrdiff_t(i));
            str_o.erase(str_o.begin()+ptrdiff_t(i));
            shp.erase(shp.begin()+ptrdiff_t(i));
            pos.pop_back();
            done=false;
            }
        }
      if (pos.size()>0)
        {
        sstr_i = str_i[0];
        sstr_o = str_o[0];
        }

Martin Reinecke's avatar
Martin Reinecke committed
437
      if (nshares==1) return;
Martin Reinecke's avatar
Martin Reinecke committed
438 439
      if (nshares==0) throw std::runtime_error("can't run with zero threads");
      if (myshare>=nshares) throw std::runtime_error("impossible share requested");
Martin Reinecke's avatar
Martin Reinecke committed
440 441 442 443 444 445 446
      size_t nbase = rem/nshares;
      size_t additional = rem%nshares;
      size_t lo = myshare*nbase + ((myshare<additional) ? myshare : additional);
      size_t hi = lo+nbase+(myshare<additional);
      size_t todo = hi-lo;

      size_t chunk = rem;
Martin Reinecke's avatar
Martin Reinecke committed
447
      for (size_t i2=0, i=pos.size()-1; i2<pos.size(); ++i2,--i)
Martin Reinecke's avatar
Martin Reinecke committed
448
        {
Martin Reinecke's avatar
Martin Reinecke committed
449
        chunk /= shp[i];
Martin Reinecke's avatar
Martin Reinecke committed
450 451
        size_t n_advance = lo/chunk;
        pos[i] += n_advance;
Martin Reinecke's avatar
Martin Reinecke committed
452 453
        p_ii += ptrdiff_t(n_advance)*str_i[i];
        p_oi += ptrdiff_t(n_advance)*str_o[i];
Martin Reinecke's avatar
Martin Reinecke committed
454 455
        lo -= n_advance*chunk;
        }
Martin Reinecke's avatar
Martin Reinecke committed
456
      MR_assert(lo==0, "must not happen");
Martin Reinecke's avatar
Martin Reinecke committed
457 458 459 460
      rem = todo;
      }
    void advance(size_t n)
      {
Martin Reinecke's avatar
Martin Reinecke committed
461
      if (rem<n) throw std::runtime_error("underrun");
Martin Reinecke's avatar
Martin Reinecke committed
462 463 464 465 466 467
      for (size_t i=0; i<n; ++i)
        {
        p_i[i] = p_ii;
        p_o[i] = p_oi;
        advance_i();
        }
Martin Reinecke's avatar
Martin Reinecke committed
468 469 470 471 472 473
      uni_i = uni_o = true;
      for (size_t i=1; i<n; ++i)
        {
        uni_i = uni_i && (p_i[i]-p_i[i-1] == sstr_i);
        uni_o = uni_o && (p_o[i]-p_o[i-1] == sstr_o);
        }
Martin Reinecke's avatar
Martin Reinecke committed
474 475
      rem -= n;
      }
Martin Reinecke's avatar
Martin Reinecke committed
476 477 478 479 480 481 482 483 484 485 486 487 488 489
    ptrdiff_t iofs(size_t i) const { return p_i[0] + ptrdiff_t(i)*cstr_i; }
    ptrdiff_t iofs(size_t j, size_t i) const { return p_i[j] + ptrdiff_t(i)*cstr_i; }
    ptrdiff_t iofs_uni(size_t j, size_t i) const { return p_i[0] + ptrdiff_t(j)*sstr_i + ptrdiff_t(i)*cstr_i; }
    ptrdiff_t oofs(size_t i) const { return p_o[0] + ptrdiff_t(i)*cstr_o; }
    ptrdiff_t oofs(size_t j, size_t i) const { return p_o[j] + ptrdiff_t(i)*cstr_o; }
    ptrdiff_t oofs_uni(size_t j, size_t i) const { return p_o[0] + ptrdiff_t(j)*sstr_o + ptrdiff_t(i)*cstr_o; }
    bool uniform_i() const { return uni_i; } 
    ptrdiff_t unistride_i() const { return sstr_i; } 
    bool uniform_o() const { return uni_o; } 
    ptrdiff_t unistride_o() const { return sstr_o; } 
    size_t length_in() const { return cshp_i; }
    size_t length_out() const { return cshp_o; }
    ptrdiff_t stride_in() const { return cstr_i; }
    ptrdiff_t stride_out() const { return cstr_o; }
Martin Reinecke's avatar
Martin Reinecke committed
490 491 492 493 494 495 496
    size_t remaining() const { return rem; }
  };

class rev_iter
  {
  private:
    shape_t pos;
Martin Reinecke's avatar
Martin Reinecke committed
497
    fmav_info arr;
Martin Reinecke's avatar
Martin Reinecke committed
498 499
    std::vector<char> rev_axis;
    std::vector<char> rev_jump;
Martin Reinecke's avatar
Martin Reinecke committed
500 501 502 503 504 505
    size_t last_axis, last_size;
    shape_t shp;
    ptrdiff_t p, rp;
    size_t rem;

  public:
Martin Reinecke's avatar
Martin Reinecke committed
506
    rev_iter(const fmav_info &arr_, const shape_t &axes)
Martin Reinecke's avatar
Martin Reinecke committed
507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555
      : pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0),
        rev_jump(arr_.ndim(), 1), p(0), rp(0)
      {
      for (auto ax: axes)
        rev_axis[ax]=1;
      last_axis = axes.back();
      last_size = arr.shape(last_axis)/2 + 1;
      shp = arr.shape();
      shp[last_axis] = last_size;
      rem=1;
      for (auto i: shp)
        rem *= i;
      }
    void advance()
      {
      --rem;
      for (int i_=int(pos.size())-1; i_>=0; --i_)
        {
        auto i = size_t(i_);
        p += arr.stride(i);
        if (!rev_axis[i])
          rp += arr.stride(i);
        else
          {
          rp -= arr.stride(i);
          if (rev_jump[i])
            {
            rp += ptrdiff_t(arr.shape(i))*arr.stride(i);
            rev_jump[i] = 0;
            }
          }
        if (++pos[i] < shp[i])
          return;
        pos[i] = 0;
        p -= ptrdiff_t(shp[i])*arr.stride(i);
        if (rev_axis[i])
          {
          rp -= ptrdiff_t(arr.shape(i)-shp[i])*arr.stride(i);
          rev_jump[i] = 1;
          }
        else
          rp -= ptrdiff_t(shp[i])*arr.stride(i);
        }
      }
    ptrdiff_t ofs() const { return p; }
    ptrdiff_t rev_ofs() const { return rp; }
    size_t remaining() const { return rem; }
  };

Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
556
template<typename T, typename T0> MRUTIL_NOINLINE aligned_array<T> alloc_tmp
  (const fmav_info &info, size_t axsize)
  {
  // Scratch buffer for one line of length axsize. When the array holds at
  // least vlen such lines, the buffer is sized for vlen lines so that a
  // whole SIMD batch fits.
  constexpr auto vlen = native_simd<T0>::size();
  const auto nlines = info.size()/axsize;
  const auto nbatch = (nlines<vlen) ? 1 : vlen;
  return aligned_array<T>(axsize*nbatch);
  }

Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
565
template <typename T, size_t vlen> MRUTIL_NOINLINE void copy_input(const multi_iter<vlen> &it,
  const fmav<Cmplx<T>> &src, Cmplx<native_simd<T>> *MRUTIL_RESTRICT dst)
  {
  // Interleaves vlen complex input lines into one line of SIMD complex
  // values: lane j of dst[i] receives element i of line j.
  const size_t len = it.length_in();
  if (!it.uniform_i())
    {
    // irregular line spacing: look up every lane's offset individually
    for (size_t i=0; i<len; ++i)
      {
      Cmplx<native_simd<T>> vec;
      for (size_t j=0; j<vlen; ++j)
        {
        const auto val = src[it.iofs(j,i)];
        vec.r[j] = val.r;
        vec.i[j] = val.i;
        }
      dst[i] = vec;
      }
    return;
    }

  // Uniformly spaced lines: address everything relative to the first one.
  // The unit-stride cases are spelled out separately, presumably so that
  // the compiler can specialise those loops.
  const auto base = &src[it.iofs_uni(0,0)];
  const auto lstr = it.unistride_i();  // offset between consecutive lines
  const auto estr = it.stride_in();    // offset between elements of a line
  if (estr==1)
    for (size_t i=0; i<len; ++i)
      {
      Cmplx<native_simd<T>> vec;
      for (size_t j=0; j<vlen; ++j)
        {
        const auto val = base[ptrdiff_t(j)*lstr+ptrdiff_t(i)];
        vec.r[j] = val.r;
        vec.i[j] = val.i;
        }
      dst[i] = vec;
      }
  else if (lstr==1)
    for (size_t i=0; i<len; ++i)
      {
      Cmplx<native_simd<T>> vec;
      for (size_t j=0; j<vlen; ++j)
        {
        const auto val = base[ptrdiff_t(j)+ptrdiff_t(i)*estr];
        vec.r[j] = val.r;
        vec.i[j] = val.i;
        }
      dst[i] = vec;
      }
  else
    for (size_t i=0; i<len; ++i)
      {
      Cmplx<native_simd<T>> vec;
      for (size_t j=0; j<vlen; ++j)
        {
        const auto val = src[it.iofs_uni(j,i)];
        vec.r[j] = val.r;
        vec.i[j] = val.i;
        }
      dst[i] = vec;
      }
  }

Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
624
template <typename T, size_t vlen> MRUTIL_NOINLINE void copy_input(const multi_iter<vlen> &it,
  const fmav<T> &src, native_simd<T> *MRUTIL_RESTRICT dst)
  {
  // Interleaves vlen real input lines into one line of SIMD values:
  // lane j of dst[i] receives element i of line j.
  const size_t len = it.length_in();
  if (!it.uniform_i())
    {
    // irregular line spacing: per-lane offset lookup
    for (size_t i=0; i<len; ++i)
      for (size_t j=0; j<vlen; ++j)
        dst[i][j] = src[it.iofs(j,i)];
    return;
    }

  // uniformly spaced lines, addressed relative to the first one
  const auto base = &src[it.iofs_uni(0,0)];
  const auto lstr = it.unistride_i();  // offset between consecutive lines
  const auto estr = it.stride_in();    // offset between elements of a line
  if (estr==1)
    for (size_t i=0; i<len; ++i)
      for (size_t j=0; j<vlen; ++j)
        dst[i][j] = base[ptrdiff_t(j)*lstr + ptrdiff_t(i)];
  else if (lstr==1)
    for (size_t i=0; i<len; ++i)
      for (size_t j=0; j<vlen; ++j)
        dst[i][j] = base[ptrdiff_t(j) + ptrdiff_t(i)*estr];
  else
    for (size_t i=0; i<len; ++i)
      for (size_t j=0; j<vlen; ++j)
        dst[i][j] = src[it.iofs_uni(j,i)];
  }

Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
651
template <typename T, size_t vlen> MRUTIL_NOINLINE void copy_input(const multi_iter<vlen> &it,
  const fmav<T> &src, T *MRUTIL_RESTRICT dst)
  {
  // Copies the first current input line into dst. Nothing is copied when
  // dst already is that line (in-place operation).
  if (&src[it.iofs(0)] == dst) return;
  const size_t len = it.length_in();
  for (size_t i=0; i<len; ++i)
    dst[i] = src[it.iofs(i)];
  }

Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
659
template<typename T, size_t vlen> MRUTIL_NOINLINE void copy_output(const multi_iter<vlen> &it,
  const Cmplx<native_simd<T>> *MRUTIL_RESTRICT src, fmav<Cmplx<T>> &dst)
  {
  // Scatters one line of SIMD complex values back to vlen output lines:
  // lane j of src[i] becomes element i of line j.
  const size_t len = it.length_out();
  if (!it.uniform_o())
    {
    // irregular line spacing: per-lane offset lookup
    auto raw = dst.vdata();
    for (size_t i=0; i<len; ++i)
      for (size_t j=0; j<vlen; ++j)
        raw[it.oofs(j,i)].Set(src[i].r[j],src[i].i[j]);
    return;
    }

  // uniformly spaced lines, addressed relative to the first one
  const auto base = &dst.vraw(it.oofs_uni(0,0));
  const auto lstr = it.unistride_o();  // offset between consecutive lines
  const auto estr = it.stride_out();   // offset between elements of a line
  if (estr==1)
    for (size_t i=0; i<len; ++i)
      for (size_t j=0; j<vlen; ++j)
        base[ptrdiff_t(j)*lstr + ptrdiff_t(i)].Set(src[i].r[j],src[i].i[j]);
  else if (lstr==1)
    for (size_t i=0; i<len; ++i)
      for (size_t j=0; j<vlen; ++j)
        base[ptrdiff_t(j) + ptrdiff_t(i)*estr].Set(src[i].r[j],src[i].i[j]);
  else
    for (size_t i=0; i<len; ++i)
      for (size_t j=0; j<vlen; ++j)
        dst.vraw(it.oofs_uni(j,i)).Set(src[i].r[j],src[i].i[j]);
  }

Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
689
template<typename T, size_t vlen> MRUTIL_NOINLINE void copy_output(const multi_iter<vlen> &it,
  const native_simd<T> *MRUTIL_RESTRICT src, fmav<T> &dst)
  {
  // Scatters one line of SIMD values back to vlen real output lines:
  // lane j of src[i] becomes element i of line j.
  const size_t len = it.length_out();
  if (!it.uniform_o())
    {
    // irregular line spacing: per-lane offset lookup
    auto raw = dst.vdata();
    for (size_t i=0; i<len; ++i)
      for (size_t j=0; j<vlen; ++j)
        raw[it.oofs(j,i)] = src[i][j];
    return;
    }

  // uniformly spaced lines, addressed relative to the first one
  const auto base = &dst.vraw(it.oofs_uni(0,0));
  const auto lstr = it.unistride_o();  // offset between consecutive lines
  const auto estr = it.stride_out();   // offset between elements of a line
  if (estr==1)
    for (size_t i=0; i<len; ++i)
      for (size_t j=0; j<vlen; ++j)
        base[ptrdiff_t(j)*lstr + ptrdiff_t(i)] = src[i][j];
  else if (lstr==1)
    for (size_t i=0; i<len; ++i)
      for (size_t j=0; j<vlen; ++j)
        base[ptrdiff_t(j) + ptrdiff_t(i)*estr] = src[i][j];
  else
    for (size_t i=0; i<len; ++i)
      for (size_t j=0; j<vlen; ++j)
        dst.vraw(it.oofs_uni(j,i)) = src[i][j];
  }

Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
719
template<typename T, size_t vlen> MRUTIL_NOINLINE void copy_output(const multi_iter<vlen> &it,
  const T *MRUTIL_RESTRICT src, fmav<T> &dst)
  {
  // Writes one transformed line from `src` into `dst`. Nothing needs to be
  // copied if `src` already is the output line (transform done in place).
  auto target = dst.vdata();
  const bool in_place = (src == &dst[it.oofs(0)]);
  if (!in_place)
    for (size_t i=0, n=it.length_out(); i<n; ++i)
      target[it.oofs(i)] = src[i];
  }

Martin Reinecke's avatar
Martin Reinecke committed
728
template <typename T> struct add_vec { using type = native_simd<T>; };
Martin Reinecke's avatar
Martin Reinecke committed
729
template <typename T> struct add_vec<Cmplx<T>>
Martin Reinecke's avatar
Martin Reinecke committed
730
  { using type = Cmplx<native_simd<T>>; };
Martin Reinecke's avatar
Martin Reinecke committed
731 732 733
template <typename T> using add_vec_t = typename add_vec<T>::type;

template<typename Tplan, typename T, typename T0, typename Exec>
Martin Reinecke's avatar
stage 1  
Martin Reinecke committed
734
MRUTIL_NOINLINE void general_nd(const fmav<T> &in, fmav<T> &out,
Martin Reinecke's avatar
Martin Reinecke committed
735 736 737
  const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec,
  const bool allow_inplace=true)
  {
Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
738
  std::unique_ptr<Tplan> plan;
Martin Reinecke's avatar
Martin Reinecke committed
739 740 741 742 743

  for (size_t iax=0; iax<axes.size(); ++iax)
    {
    size_t len=in.shape(axes[iax]);
    if ((!plan) || (len!=plan->length()))
Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
744
      plan = std::make_unique<Tplan>(len);
Martin Reinecke's avatar
Martin Reinecke committed
745 746

    execParallel(
Martin Reinecke's avatar
Martin Reinecke committed
747
      util::thread_count(nthreads, in, axes[iax], native_simd<T0>::size()),
Martin Reinecke's avatar
rework  
Martin Reinecke committed
748
      [&](Scheduler &sched) {
Martin Reinecke's avatar
Martin Reinecke committed
749 750
        constexpr auto vlen = native_simd<T0>::size();
        auto storage = alloc_tmp<T,T0>(in, len);
Martin Reinecke's avatar
Martin Reinecke committed
751
        const auto &tin(iax==0? in : out);
Martin Reinecke's avatar
rework  
Martin Reinecke committed
752
        multi_iter<vlen> it(tin, out, axes[iax], sched.num_threads(), sched.thread_num());
Martin Reinecke's avatar
Martin Reinecke committed
753
#ifndef MRUTIL_NO_SIMD
Martin Reinecke's avatar
Martin Reinecke committed
754 755 756 757 758 759 760 761 762 763 764
        if (vlen>1)
          while (it.remaining()>=vlen)
            {
            it.advance(vlen);
            auto tdatav = reinterpret_cast<add_vec_t<T> *>(storage.data());
            exec(it, tin, out, tdatav, *plan, fct);
            }
#endif
        while (it.remaining()>0)
          {
          it.advance(1);
Martin Reinecke's avatar
Martin Reinecke committed
765
          auto buf = allow_inplace && it.stride_out() == 1 ?
Martin Reinecke's avatar
stage 2  
Martin Reinecke committed
766
            &out.vraw(it.oofs(0)) : reinterpret_cast<T *>(storage.data());
Martin Reinecke's avatar
Martin Reinecke committed
767 768 769 770 771 772 773 774 775 776 777
          exec(it, tin, out, buf, *plan, fct);
          }
      });  // end of parallel region
    fct = T0(1); // factor has been applied, use 1 for remaining axes
    }
  }

struct ExecC2C
  {
  bool forward;

Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
778
  template <typename T0, typename T, size_t vlen> MRUTIL_NOINLINE void operator() (
Martin Reinecke's avatar
stage 1  
Martin Reinecke committed
779 780
    const multi_iter<vlen> &it, const fmav<Cmplx<T0>> &in,
    fmav<Cmplx<T0>> &out, T *buf, const pocketfft_c<T0> &plan, T0 fct) const
Martin Reinecke's avatar
Martin Reinecke committed
781 782 783 784 785 786 787
    {
    copy_input(it, in, buf);
    plan.exec(buf, fct, forward);
    copy_output(it, buf, out);
    }
  };

Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
788
template <typename T, size_t vlen> MRUTIL_NOINLINE void copy_hartley(const multi_iter<vlen> &it,
  const native_simd<T> *MRUTIL_RESTRICT src, fmav<T> &dst)
  {
  // Rearranges vlen packed real-FFT results (src[k][j] = entry k of lane j)
  // into Hartley order and writes them to `dst`:
  //   out[0]   = src[0]
  //   out[m]   = src[k] + src[k+1]   (k = 1, 3, 5, ...; m = 1, 2, ...)
  //   out[n-m] = src[k] - src[k+1]
  // and, if one entry is left over (even n), out[m] = src[k].
  const size_t n = it.length_out();

  // `at(j,m)` must yield a reference to output entry m of lane j; the
  // unpacking itself is the same for every addressing scheme.
  auto unpack = [&](auto &&at)
    {
    for (size_t j=0; j<vlen; ++j)
      at(j,0) = src[0][j];
    size_t k=1, lo=1, hi=n-1;
    for (; k<n-1; k+=2, ++lo, --hi)
      for (size_t j=0; j<vlen; ++j)
        {
        at(j,lo) = src[k][j]+src[k+1][j];
        at(j,hi) = src[k][j]-src[k+1][j];
        }
    if (k<n)
      for (size_t j=0; j<vlen; ++j)
        at(j,lo) = src[k][j];
    };

  if (!it.uniform_o())
    {
    auto base = dst.vdata();
    unpack([&](size_t j, size_t m) -> auto &
      { return base[it.oofs(j,m)]; });
    return;
    }

  // equally spaced lanes: keep separate instantiations so that a unit
  // stride stays visible to the compiler
  auto base = &dst.vraw(it.oofs_uni(0,0));
  const auto lstr = it.unistride_o();
  const auto estr = it.stride_out();
  if (estr==1)
    unpack([&](size_t j, size_t m) -> auto &
      { return base[ptrdiff_t(j)*lstr + ptrdiff_t(m)]; });
  else if (lstr==1)
    unpack([&](size_t j, size_t m) -> auto &
      { return base[ptrdiff_t(j) + ptrdiff_t(m)*estr]; });
  else
    unpack([&](size_t j, size_t m) -> auto &
      { return base[ptrdiff_t(j)*lstr + ptrdiff_t(m)*estr]; });
  }

Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
860
template <typename T, size_t vlen> MRUTIL_NOINLINE void copy_hartley(const multi_iter<vlen> &it,
  const T *MRUTIL_RESTRICT src, fmav<T> &dst)
  {
  // Rearranges one packed real-FFT result into Hartley order:
  // out[0] = src[0]; for k = 1, 3, 5, ... and m = 1, 2, ...:
  // out[m] = src[k]+src[k+1], out[n-m] = src[k]-src[k+1];
  // for even n the last src entry has no partner and is copied as is.
  auto out = dst.vdata();
  const size_t n = it.length_out();
  out[it.oofs(0)] = src[0];
  size_t k=1, lo=1, hi=n-1;
  while (k<n-1)
    {
    out[it.oofs(lo)] = src[k]+src[k+1];
    out[it.oofs(hi)] = src[k]-src[k+1];
    k+=2; ++lo; --hi;
    }
  if (k<n)
    out[it.oofs(lo)] = src[k];
  }

struct ExecHartley
  {
Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
877
  template <typename T0, typename T, size_t vlen> MRUTIL_NOINLINE void operator () (
Martin Reinecke's avatar
stage 1  
Martin Reinecke committed
878
    const multi_iter<vlen> &it, const fmav<T0> &in, fmav<T0> &out,
Martin Reinecke's avatar
Martin Reinecke committed
879 880 881 882 883 884 885 886 887 888 889 890 891 892 893
    T * buf, const pocketfft_r<T0> &plan, T0 fct) const
    {
    copy_input(it, in, buf);
    plan.exec(buf, fct, true);
    copy_hartley(it, buf, out);
    }
  };

struct ExecDcst
  {
  bool ortho;
  int type;
  bool cosine;

  template <typename T0, typename T, typename Tplan, size_t vlen>
Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
894
  MRUTIL_NOINLINE void operator () (const multi_iter<vlen> &it, const fmav<T0> &in,
Martin Reinecke's avatar
stage 1  
Martin Reinecke committed
895
    fmav <T0> &out, T * buf, const Tplan &plan, T0 fct) const
Martin Reinecke's avatar
Martin Reinecke committed
896 897 898 899 900 901 902
    {
    copy_input(it, in, buf);
    plan.exec(buf, fct, ortho, type, cosine);
    copy_output(it, buf, out);
    }
  };

Martin Reinecke's avatar
Martin Reinecke committed
903
template<typename T> MRUTIL_NOINLINE void general_r2c(
Martin Reinecke's avatar
stage 1  
Martin Reinecke committed
904
  const fmav<T> &in, fmav<Cmplx<T>> &out, size_t axis, bool forward, T fct,
Martin Reinecke's avatar
Martin Reinecke committed
905 906
  size_t nthreads)
  {
Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
907
  auto plan = std::make_unique<pocketfft_r<T>>(in.shape(axis));
Martin Reinecke's avatar
Martin Reinecke committed
908 909
  size_t len=in.shape(axis);
  execParallel(
Martin Reinecke's avatar
Martin Reinecke committed
910
    util::thread_count(nthreads, in, axis, native_simd<T>::size()),
Martin Reinecke's avatar
rework  
Martin Reinecke committed
911
    [&](Scheduler &sched) {
Martin Reinecke's avatar
Martin Reinecke committed
912 913
    constexpr auto vlen = native_simd<T>::size();
    auto storage = alloc_tmp<T,T>(in, len);
Martin Reinecke's avatar
rework  
Martin Reinecke committed
914
    multi_iter<vlen> it(in, out, axis, sched.num_threads(), sched.thread_num());
Martin Reinecke's avatar
Martin Reinecke committed
915
#ifndef MRUTIL_NO_SIMD
Martin Reinecke's avatar
Martin Reinecke committed
916 917 918 919
    if (vlen>1)
      while (it.remaining()>=vlen)
        {
        it.advance(vlen);
Martin Reinecke's avatar
Martin Reinecke committed
920
        auto tdatav = reinterpret_cast<native_simd<T> *>(storage.data());
Martin Reinecke's avatar
Martin Reinecke committed
921 922
        copy_input(it, in, tdatav);
        plan->exec(tdatav, fct, true);
Martin Reinecke's avatar
stage 2  
Martin Reinecke committed
923
        auto vout = out.vdata();
Martin Reinecke's avatar
Martin Reinecke committed
924
        for (size_t j=0; j<vlen; ++j)
Martin Reinecke's avatar
stage 2  
Martin Reinecke committed
925
          vout[it.oofs(j,0)].Set(tdatav[0][j]);
Martin Reinecke's avatar
Martin Reinecke committed
926 927 928 929
        size_t i=1, ii=1;
        if (forward)
          for (; i<len-1; i+=2, ++ii)
            for (size_t j=0; j<vlen; ++j)
Martin Reinecke's avatar
stage 2  
Martin Reinecke committed
930
              vout[it.oofs(j,ii)].Set(tdatav[i][j], tdatav[i+1][j]);
Martin Reinecke's avatar
Martin Reinecke committed
931 932 933
        else
          for (; i<len-1; i+=2, ++ii)
            for (size_t j=0; j<vlen; ++j)
Martin Reinecke's avatar
stage 2  
Martin Reinecke committed
934
              vout[it.oofs(j,ii)].Set(tdatav[i][j], -tdatav[i+1][j]);
Martin Reinecke's avatar
Martin Reinecke committed
935 936
        if (i<len)
          for (size_t j=0; j<vlen; ++j)
Martin Reinecke's avatar
stage 2  
Martin Reinecke committed
937
            vout[it.oofs(j,ii)].Set(tdatav[i][j]);
Martin Reinecke's avatar
Martin Reinecke committed
938 939 940 941 942 943 944 945
        }
#endif
    while (it.remaining()>0)
      {
      it.advance(1);
      auto tdata = reinterpret_cast<T *>(storage.data());
      copy_input(it, in, tdata);
      plan->exec(tdata, fct, true);
Martin Reinecke's avatar
stage 2  
Martin Reinecke committed
946 947
      auto vout = out.vdata();
      vout[it.oofs(0)].Set(tdata[0]);
Martin Reinecke's avatar
Martin Reinecke committed
948 949 950
      size_t i=1, ii=1;
      if (forward)
        for (; i<len-1; i+=2, ++ii)
Martin Reinecke's avatar
stage 2  
Martin Reinecke committed
951
          vout[it.oofs(ii)].Set(tdata[i], tdata[i+1]);
Martin Reinecke's avatar
Martin Reinecke committed
952 953
      else
        for (; i<len-1; i+=2, ++ii)
Martin Reinecke's avatar
stage 2  
Martin Reinecke committed
954
          vout[it.oofs(ii)].Set(tdata[i], -tdata[i+1]);
Martin Reinecke's avatar
Martin Reinecke committed
955
      if (i<len)
Martin Reinecke's avatar
stage 2  
Martin Reinecke committed
956
        vout[it.oofs(ii)].Set(tdata[i]);
Martin Reinecke's avatar
Martin Reinecke committed
957 958 959
      }
    });  // end of parallel region
  }
Martin Reinecke's avatar
Martin Reinecke committed
960
template<typename T> MRUTIL_NOINLINE void general_c2r(
  const fmav<Cmplx<T>> &in, fmav<T> &out, size_t axis, bool forward, T fct,
  size_t nthreads)
  {
  // Complex-to-real FFT along `axis`. The complex entries of each input line
  // are packed into a real buffer of length len = out.shape(axis): entry 0
  // takes only the real part, the following ones fill pairs (k, k+1) for odd
  // k, and for even len the last entry takes only a real part. Imaginary
  // parts are negated when `forward`. The buffer then goes through a
  // backward real FFT and is written to `out`.
  const size_t len = out.shape(axis);
  auto plan = std::make_unique<pocketfft_r<T>>(len);
  execParallel(
    util::thread_count(nthreads, in, axis, native_simd<T>::size()),
    [&](Scheduler &sched) {
      constexpr auto vlen = native_simd<T>::size();
      auto storage = alloc_tmp<T,T>(out, len);
      multi_iter<vlen> it(in, out, axis, sched.num_threads(), sched.thread_num());
#ifndef MRUTIL_NO_SIMD
      if (vlen>1)
        {
        // vlen lines at a time, one per SIMD lane
        auto vbuf = reinterpret_cast<native_simd<T> *>(storage.data());
        while (it.remaining()>=vlen)
          {
          it.advance(vlen);
          for (size_t lane=0; lane<vlen; ++lane)
            vbuf[0][lane] = in[it.iofs(lane,0)].r;
          size_t k=1, m=1;
          if (!forward)
            {
            for (; k<len-1; k+=2, ++m)
              for (size_t lane=0; lane<vlen; ++lane)
                {
                const auto &c = in[it.iofs(lane,m)];
                vbuf[k  ][lane] = c.r;
                vbuf[k+1][lane] = c.i;
                }
            }
          else
            {
            for (; k<len-1; k+=2, ++m)
              for (size_t lane=0; lane<vlen; ++lane)
                {
                const auto &c = in[it.iofs(lane,m)];
                vbuf[k  ][lane] =  c.r;
                vbuf[k+1][lane] = -c.i;
                }
            }
          if (k<len)
            for (size_t lane=0; lane<vlen; ++lane)
              vbuf[k][lane] = in[it.iofs(lane,m)].r;
          plan->exec(vbuf, fct, false);
          copy_output(it, vbuf, out);
          }
        }
#endif
      // remaining lines, one at a time
      auto sbuf = reinterpret_cast<T *>(storage.data());
      while (it.remaining()>0)
        {
        it.advance(1);
        sbuf[0] = in[it.iofs(0)].r;
        size_t k=1, m=1;
        if (!forward)
          {
          for (; k<len-1; k+=2, ++m)
            {
            const auto &c = in[it.iofs(m)];
            sbuf[k  ] = c.r;
            sbuf[k+1] = c.i;
            }
          }
        else
          {
          for (; k<len-1; k+=2, ++m)
            {
            const auto &c = in[it.iofs(m)];
            sbuf[k  ] =  c.r;
            sbuf[k+1] = -c.i;
            }
          }
        if (k<len)
          sbuf[k] = in[it.iofs(m)].r;
        plan->exec(sbuf, fct, false);
        copy_output(it, sbuf, out);
        }
    });  // end of parallel region
  }

struct ExecR2R
  {
  bool r2c, forward;

Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
1036
  template <typename T0, typename T, size_t vlen> MRUTIL_NOINLINE void operator () (
Martin Reinecke's avatar
stage 1  
Martin Reinecke committed
1037
    const multi_iter<vlen> &it, const fmav<T0> &in, fmav<T0> &out, T * buf,
Martin Reinecke's avatar
Martin Reinecke committed
1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051
    const pocketfft_r<T0> &plan, T0 fct) const
    {
    copy_input(it, in, buf);
    if ((!r2c) && forward)
      for (size_t i=2; i<it.length_out(); i+=2)
        buf[i] = -buf[i];
    plan.exec(buf, fct, forward);
    if (r2c && (!forward))
      for (size_t i=2; i<it.length_out(); i+=2)
        buf[i] = -buf[i];
    copy_output(it, buf, out);
    }
  };

Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
1052
template<typename T> MRUTIL_NOINLINE void c2c(const fmav<std::complex<T>> &in,
  fmav<std::complex<T>> &out, const shape_t &axes, bool forward,
  T fct, size_t nthreads=1)
  {
  // Complex-to-complex FFT of `in` over `axes`, scaled by `fct`, into `out`.
  util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
  if (in.size()==0) return;
  // view the std::complex<T> storage as pocketfft's Cmplx<T>
  using C = Cmplx<T>;
  const fmav<C> csrc(reinterpret_cast<const C *>(in.data()), in);
  fmav<C> cdst(reinterpret_cast<C *>(out.vdata()), out, out.writable());
  const ExecC2C exec{forward};
  general_nd<pocketfft_c<T>>(csrc, cdst, axes, fct, nthreads, exec);
  }

Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
1063
template<typename T> MRUTIL_NOINLINE void dct(const fmav<T> &in, fmav<T> &out,
  const shape_t &axes, int type, T fct, bool ortho, size_t nthreads=1)
  {
  // Discrete cosine transform of the given type (1..4) over `axes`.
  // Throws std::invalid_argument for any other type.
  const bool valid_type = (type>=1) && (type<=4);
  if (!valid_type) throw std::invalid_argument("invalid DCT type");
  util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
  if (in.size()==0) return;
  const ExecDcst exec{ortho, type, true};  // true: cosine variant
  switch (type)
    {
    case 1:
      general_nd<T_dct1<T>>(in, out, axes, fct, nthreads, exec);
      break;
    case 4:
      general_nd<T_dcst4<T>>(in, out, axes, fct, nthreads, exec);
      break;
    default:  // types 2 and 3 share one plan type
      general_nd<T_dcst23<T>>(in, out, axes, fct, nthreads, exec);
      break;
    }
  }

Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
1078
template<typename T> MRUTIL_NOINLINE void dst(const fmav<T> &in, fmav<T> &out,
  const shape_t &axes, int type, T fct, bool ortho, size_t nthreads=1)
  {
  // Discrete sine transform of the given type (1..4) over `axes`.
  // Throws std::invalid_argument for any other type.
  if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type");
  util::sanity_check_onetype(in, out, in.data()==out.data(), axes);
  // Empty input: nothing to do. Returning here (as dct() and c2c() do) avoids
  // constructing plans, possibly of length 0, for a no-op transform.
  if (in.size()==0) return;
  const ExecDcst exec{ortho, type, false};  // false: sine variant
  if (type==1)
    general_nd<T_dst1<T>>(in, out, axes, fct, nthreads, exec);
  else if (type==4)
    general_nd<T_dcst4<T>>(in, out, axes, fct, nthreads, exec);
  else  // types 2 and 3 share one plan type
    general_nd<T_dcst23<T>>(in, out, axes, fct, nthreads, exec);
  }

Martin Reinecke's avatar
cleanup  
Martin Reinecke committed
1092
template<typename T> MRUTIL_NOINLINE void r2c(const fmav<T> &in,
Martin Reinecke's avatar
stage 1  
Martin Reinecke committed
1093
  fmav<std::complex<T>> &out, size_t axis, bool forward, T fct,
Martin Reinecke's avatar
Martin Reinecke committed
1094 1095
  size_t nthreads=1)
  {
Martin Reinecke's avatar
Martin Reinecke committed
1096 1097
  util::sanity_check_cr