From 0f5781c927b271f00f6ce6d76404220832ae5b8a Mon Sep 17 00:00:00 2001 From: Martin Reinecke Date: Wed, 17 Jul 2019 19:57:38 +0200 Subject: [PATCH] fix normalization; add tests --- pocketfft_hdronly.h | 4 ++-- pypocketfft.cc | 4 ++++ test.py | 23 ++++++++++++++++++++++- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index 1923c15..f7619a0 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -2344,8 +2344,8 @@ template class T_cosq if (N==2) { T TSQX = sqrt2*c[1]; - c[1] = c[0]-TSQX; - c[0] = c[0]+TSQX; + c[1] = fct*(c[0]-TSQX); + c[0] = fct*(c[0]+TSQX); return; } size_t NS2 = (N+1)/2; diff --git a/pypocketfft.cc b/pypocketfft.cc index 608001e..503afdd 100644 --- a/pypocketfft.cc +++ b/pypocketfft.cc @@ -240,6 +240,8 @@ template py::array r2r_dct23_internal(const py::array &in, { py::gil_scoped_release release; T fct = norm_fct(inorm, dims, axes); + if (inorm==2) fct*=T(1/ldbl_t(2)); + if (inorm==1) fct*=T(1/sqrt(ldbl_t(2))); pocketfft::r2r_dct23(dims, s_in, s_out, axes, forward, d_in, d_out, fct, nthreads); } @@ -267,6 +269,8 @@ template py::array r2r_dst23_internal(const py::array &in, { py::gil_scoped_release release; T fct = norm_fct(inorm, dims, axes); + if (inorm==2) fct*=T(1/ldbl_t(2)); + if (inorm==1) fct*=T(1/sqrt(ldbl_t(2))); pocketfft::r2r_dst23(dims, s_in, s_out, axes, forward, d_in, d_out, fct, nthreads); } diff --git a/test.py b/test.py index 5c43f82..17610ed 100644 --- a/test.py +++ b/test.py @@ -17,6 +17,13 @@ def _l2error(a, b): return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2)) +def _assert_close(a, b, epsilon): + err = _l2error(a, b) + if (err >= epsilon): + print("Error: {} > {}".format(err, epsilon)) + assert_(err