diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h
index b875db83f9599ef4cb4fa58a96f9e2791e2cc7f2..ac2ac7f85dd449bfb1aaa2ec408b522f1fd7b3d3 100644
--- a/pocketfft_hdronly.h
+++ b/pocketfft_hdronly.h
@@ -2293,6 +2293,7 @@ template<typename T0> class pocketfft_r
 //
 // sine/cosine transforms
 //
+
 template<typename T0> class T_dct1
   {
   private:
@@ -2302,7 +2303,8 @@ template<typename T0> class T_dct1
     POCKETFFT_NOINLINE T_dct1(size_t length)
       : fftplan(2*(length-1)) {}
 
-    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const
+    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho,
+      int /*type*/, bool /*cosine*/) const
       {
       constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
       size_t N=fftplan.length(), n=N/2+1;
@@ -2323,22 +2325,42 @@ template<typename T0> class T_dct1
     size_t length() const { return fftplan.length()/2+1; }
   };
 
-template<typename T0> class T_dct2
+template<typename T0> class T_dst1
   {
   private:
     pocketfft_r<T0> fftplan;
-    vector<T0> twiddle;
 
   public:
-    POCKETFFT_NOINLINE T_dct2(size_t length)
-      : fftplan(length), twiddle(length)
+    POCKETFFT_NOINLINE T_dst1(size_t length)
+      : fftplan(2*(length+1)) {}
+
+    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct,
+      bool /*ortho*/, int /*type*/, bool /*cosine*/) const
       {
-      constexpr T0 pi = T0(3.141592653589793238462643383279502884197L);
-      for (size_t i=0; i<length; ++i)
-        twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length)));
+      size_t N=fftplan.length(), n=N/2-1;
+      arr<T> tmp(N);
+      tmp[0] = tmp[n+1] = c[0]*0;
+      for (size_t i=0; i<n; ++i)
+        {
+        tmp[i+1] = c[i];
+        tmp[N-1-i] = -c[i];
+        }
+      fftplan.forward(tmp.data(), fct);
+      for (size_t i=0; i<n; ++i)
+        c[i] = -tmp[2*i+2];
       }
 
-    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const
+    size_t length() const { return fftplan.length()/2-1; }
+  };
+
+template<typename T0> class T_dcst23
+  {
+  private:
+    pocketfft_r<T0> fftplan;
+    vector<T0> twiddle;
+
+    template<typename T> POCKETFFT_NOINLINE void exec2c(T c[], T0 fct,
+      bool ortho) const
       {
       constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
       size_t N=length();
@@ -2368,25 +2390,8 @@ template<typename T0> class T_dct2
       if (ortho) c[0]/=sqrt2;
       }
 
-    size_t length() const { return fftplan.length(); }
-  };
-
-template<typename T0> class T_dct3
-  {
-  private:
-    pocketfft_r<T0> fftplan;
-    vector<T0> twiddle;
-
-  public:
-    POCKETFFT_NOINLINE T_dct3(size_t length)
-      : fftplan(length), twiddle(length)
-      {
-      constexpr T0 pi = T0(3.141592653589793238462643383279502884197L);
-      for (size_t i=0; i<length; ++i)
-        twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length)));
-      }
-
-    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const
+    template<typename T> POCKETFFT_NOINLINE void exec3c(T c[], T0 fct,
+      bool ortho) const
       {
       constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
       size_t N=length();
@@ -2417,10 +2422,55 @@ template<typename T0> class T_dct3
         }
       }
 
+    template<typename T> POCKETFFT_NOINLINE void exec2s(T c[], T0 fct,
+      bool ortho) const
+      {
+      constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
+      size_t N=length();
+      for (size_t k=1; k<N; k+=2)
+        c[k] = -c[k];
+      exec2c(c, fct, false);
+      for (size_t k=0, kc=N-1; k<kc; ++k, --kc)
+        swap(c[k], c[kc]);
+      if (ortho) c[0]/=sqrt2;
+      }
+
+    template<typename T> POCKETFFT_NOINLINE void exec3s(T c[], T0 fct,
+      bool ortho) const
+      {
+      constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
+      size_t N=length();
+      if (ortho) c[0]*=sqrt2;
+      size_t NS2 = N/2;
+      for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
+        swap(c[k], c[kc]);
+      exec3c(c, fct, false);
+      for (size_t k=1; k<N; k+=2)
+        c[k] = -c[k];
+      }
+
+  public:
+    POCKETFFT_NOINLINE T_dcst23(size_t length)
+      : fftplan(length), twiddle(length)
+      {
+      constexpr T0 pi = T0(3.141592653589793238462643383279502884197L);
+      for (size_t i=0; i<length; ++i)
+        twiddle[i] = T0(cos(0.5*pi*T0(i+1)/T0(length)));
+      }
+
+    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho,
+      int type, bool cosine) const
+      {
+      if (type==2)
+        cosine ? exec2c(c, fct, ortho) : exec2s(c, fct, ortho);
+      else
+        cosine ? exec3c(c, fct, ortho) : exec3s(c, fct, ortho);
+      }
+
     size_t length() const { return fftplan.length(); }
   };
 
-template<typename T0> class T_dct4
+template<typename T0> class T_dcst4
   {
   // even length algorithm from
   // https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/
@@ -2430,23 +2480,7 @@ template<typename T0> class T_dct4
     unique_ptr<pocketfft_r<T0>> rfft;
     arr<cmplx<T0>> C2;
 
-  public:
-    POCKETFFT_NOINLINE T_dct4(size_t length)
-      : N(length),
-        fft((N&1) ? nullptr : new pocketfft_c<T0>(N/2)),
-        rfft((N&1)? new pocketfft_r<T0>(N) : nullptr),
-        C2((N&1) ? 0 : N/2)
-      {
-      constexpr T0 pi = T0(3.141592653589793238462643383279502884197L);
-      if ((N&1)==0)
-        for (size_t i=0; i<N/2; ++i)
-          {
-          T0 ang = -pi/T0(N)*(T0(i)+T0(0.125));
-          C2[i].Set(cos(ang), sin(ang));
-          }
-      }
-
-    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const
+    template<typename T> POCKETFFT_NOINLINE void exec4c(T c[], T0 fct) const
       {
       constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
       if (N&1)
@@ -2512,106 +2546,38 @@ template<typename T0> class T_dct4
         }
       }
 
-    size_t length() const { return N; }
-  };
-
-template<typename T0> class T_dst1
-  {
-  private:
-    pocketfft_r<T0> fftplan;
-
-  public:
-    POCKETFFT_NOINLINE T_dst1(size_t length)
-      : fftplan(2*(length+1)) {}
-
-    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/) const
+    template<typename T> POCKETFFT_NOINLINE void exec4s(T c[], T0 fct) const
       {
-      size_t N=fftplan.length(), n=N/2-1;
-      arr<T> tmp(N);
-      tmp[0] = tmp[n+1] = c[0]*0;
-      for (size_t i=0; i<n; ++i)
-        {
-        tmp[i+1] = c[i];
-        tmp[N-1-i] = -c[i];
-        }
-      fftplan.forward(tmp.data(), fct);
-      for (size_t i=0; i<n; ++i)
-        c[i] = -tmp[2*i+2];
-      }
-
-    size_t length() const { return fftplan.length()/2-1; }
-  };
-
-template<typename T0> class T_dst2
-  {
-  private:
-    T_dct2<T0> dct;
-
-  public:
-    POCKETFFT_NOINLINE T_dst2(size_t length)
-      : dct(length) {}
-
-    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho) const
-      {
-      constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
       size_t N=length();
-      for (size_t k=1; k<N; k+=2)
-        c[k] = -c[k];
-      dct.exec(c, fct, false);
-      for (size_t k=0, kc=N-1; k<kc; ++k, --kc)
-        swap(c[k], c[kc]);
-      if (ortho) c[0]/=sqrt2;
-      }
-
-    size_t length() const { return dct.length(); }
-  };
-
-template<typename T0> class T_dst3
-  {
-  private:
-    T_dct3<T0> dct;
-
-  public:
-    POCKETFFT_NOINLINE T_dst3(size_t length)
-      : dct(length) {}
-
-    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho)
-      {
-      constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
-      size_t N=length();
-      if (ortho) c[0]*=sqrt2;
       size_t NS2 = N/2;
       for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
         swap(c[k], c[kc]);
-      dct.exec(c, fct, false);
+      exec4c(c, fct);
       for (size_t k=1; k<N; k+=2)
         c[k] = -c[k];
       }
 
-    size_t length() const { return dct.length(); }
-  };
-
-template<typename T0> class T_dst4
-  {
-  private:
-    T_dct4<T0> dct;
-
   public:
-    POCKETFFT_NOINLINE T_dst4(size_t length)
-      : dct(length) {}
-
-    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool /*ortho*/)
+    POCKETFFT_NOINLINE T_dcst4(size_t length)
+      : N(length),
+        fft((N&1) ? nullptr : new pocketfft_c<T0>(N/2)),
+        rfft((N&1)? new pocketfft_r<T0>(N) : nullptr),
+        C2((N&1) ? 0 : N/2)
       {
-      size_t N=length();
-      size_t NS2 = N/2;
-      for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
-        swap(c[k], c[kc]);
-      dct.exec(c, fct, false);
-      for (size_t k=1; k<N; k+=2)
-        c[k] = -c[k];
+      constexpr T0 pi = T0(3.141592653589793238462643383279502884197L);
+      if ((N&1)==0)
+        for (size_t i=0; i<N/2; ++i)
+          {
+          T0 ang = -pi/T0(N)*(T0(i)+T0(0.125));
+          C2[i].Set(cos(ang), sin(ang));
+          }
       }
 
-    size_t length() const { return dct.length(); }
+    template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct,
+      bool /*ortho*/, int /*type*/, bool cosine) const
+      { cosine ? exec4c(c, fct) : exec4s(c, fct); }
+
+    size_t length() const { return N; }
   };
 
 
@@ -3060,7 +3026,7 @@ template<typename T> POCKETFFT_NOINLINE void general_hartley(
 
 template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst(
   const cndarr<T> &in, ndarr<T> &out, const shape_t &axes,
-  T fct, bool ortho, size_t POCKETFFT_NTHREADS)
+  T fct, bool ortho, int type, bool cosine, size_t POCKETFFT_NTHREADS)
   {
   shared_ptr<Trafo> plan;
 
@@ -3088,7 +3054,7 @@ template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst(
         for (size_t i=0; i<len; ++i)
           for (size_t j=0; j<vlen; ++j)
             tdatav[i][j] = tin[it.iofs(j,i)];
-        plan->exec(tdatav, fct, ortho);
+        plan->exec(tdatav, fct, ortho, type, cosine);
         for (size_t i=0; i<len; ++i)
           for (size_t j=0; j<vlen; ++j)
             out[it.oofs(j,i)] = tdatav[i][j];
@@ -3099,18 +3065,18 @@ template<typename Trafo, typename T> POCKETFFT_NOINLINE void general_dcst(
       it.advance(1);
       auto tdata = reinterpret_cast<T *>(storage.data());
       if ((&tin[0]==&out[0]) && (it.stride_out()==sizeof(T))) // fully in-place
-        plan->exec(&out[it.oofs(0)], fct, ortho);
+        plan->exec(&out[it.oofs(0)], fct, ortho, type, cosine);
       else if (it.stride_out()==sizeof(T)) // compute FFT in output location
         {
         for (size_t i=0; i<len; ++i)
           out[it.oofs(i)] = tin[it.iofs(i)];
-        plan->exec(&out[it.oofs(0)], fct, ortho);
+        plan->exec(&out[it.oofs(0)], fct, ortho, type, cosine);
         }
       else
         {
         for (size_t i=0; i<len; ++i)
           tdata[i] = tin[it.iofs(i)];
-        plan->exec(tdata, fct, ortho);
+        plan->exec(tdata, fct, ortho, type, cosine);
         for (size_t i=0; i<len; ++i)
           out[it.oofs(i)] = tdata[i];
         }
@@ -3374,13 +3340,13 @@ template<typename T> void dct(const shape_t &shape,
   cndarr<T> ain(data_in, shape, stride_in);
   ndarr<T> aout(data_out, shape, stride_out);
   if (type==1)
-    general_dcst<T_dct1<T>>(ain, aout, axes, fct, ortho, nthreads);
+    general_dcst<T_dct1<T>>(ain, aout, axes, fct, ortho, type, true, nthreads);
   else if (type==2)
-    general_dcst<T_dct2<T>>(ain, aout, axes, fct, ortho, nthreads);
+    general_dcst<T_dcst23<T>>(ain, aout, axes, fct, ortho, type, true, nthreads);
   else if (type==3)
-    general_dcst<T_dct3<T>>(ain, aout, axes, fct, ortho, nthreads);
+    general_dcst<T_dcst23<T>>(ain, aout, axes, fct, ortho, type, true, nthreads);
   else if (type==4)
-    general_dcst<T_dct4<T>>(ain, aout, axes, fct, ortho, nthreads);
+    general_dcst<T_dcst4<T>>(ain, aout, axes, fct, ortho, type, true, nthreads);
   else
     throw runtime_error("unsupported DCT type");
   }
@@ -3395,13 +3361,13 @@ template<typename T> void dst(const shape_t &shape,
   cndarr<T> ain(data_in, shape, stride_in);
   ndarr<T> aout(data_out, shape, stride_out);
   if (type==1)
-    general_dcst<T_dst1<T>>(ain, aout, axes, fct, ortho, nthreads);
+    general_dcst<T_dst1<T>>(ain, aout, axes, fct, ortho, type, false, nthreads);
   else if (type==2)
-    general_dcst<T_dst2<T>>(ain, aout, axes, fct, ortho, nthreads);
+    general_dcst<T_dcst23<T>>(ain, aout, axes, fct, ortho, type, false, nthreads);
   else if (type==3)
-    general_dcst<T_dst3<T>>(ain, aout, axes, fct, ortho, nthreads);
+    general_dcst<T_dcst23<T>>(ain, aout, axes, fct, ortho, type, false, nthreads);
   else if (type==4)
-    general_dcst<T_dst4<T>>(ain, aout, axes, fct, ortho, nthreads);
+    general_dcst<T_dcst4<T>>(ain, aout, axes, fct, ortho, type, false, nthreads);
   else
     throw runtime_error("unsupported DST type");
   }