diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index 1923c15c3c27c2c4cac36dfc45a3693022fea91c..f7619a065f0bf34d0bcf2354184b5a00abd7bb17 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -2344,8 +2344,8 @@ template class T_cosq if (N==2) { T TSQX = sqrt2*c[1]; - c[1] = c[0]-TSQX; - c[0] = c[0]+TSQX; + c[1] = fct*(c[0]-TSQX); + c[0] = fct*(c[0]+TSQX); return; } size_t NS2 = (N+1)/2; diff --git a/pypocketfft.cc b/pypocketfft.cc index 608001ee6ed87fbd2220eb42d063858fc3f3676d..503afdd601084f8d3986c1f3e8e3ce4a0a3d0f48 100644 --- a/pypocketfft.cc +++ b/pypocketfft.cc @@ -240,6 +240,8 @@ template py::array r2r_dct23_internal(const py::array &in, { py::gil_scoped_release release; T fct = norm_fct(inorm, dims, axes); + if (inorm==2) fct*=T(1/ldbl_t(2)); + if (inorm==1) fct*=T(1/sqrt(ldbl_t(2))); pocketfft::r2r_dct23(dims, s_in, s_out, axes, forward, d_in, d_out, fct, nthreads); } @@ -267,6 +269,8 @@ template py::array r2r_dst23_internal(const py::array &in, { py::gil_scoped_release release; T fct = norm_fct(inorm, dims, axes); + if (inorm==2) fct*=T(1/ldbl_t(2)); + if (inorm==1) fct*=T(1/sqrt(ldbl_t(2))); pocketfft::r2r_dst23(dims, s_in, s_out, axes, forward, d_in, d_out, fct, nthreads); } diff --git a/test.py b/test.py index 5c43f82776355e9d1920476c42467cc23c4f172d..17610edc25a85cf6cdcf95c6c8782c21948e1094 100644 --- a/test.py +++ b/test.py @@ -17,6 +17,13 @@ def _l2error(a, b): return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2)) +def _assert_close(a, b, epsilon): + err = _l2error(a, b) + if (err >= epsilon): + print("Error: {} > {}".format(err, epsilon)) + assert_(err