Commit d68a9f1e authored by Martin Reinecke's avatar Martin Reinecke
Browse files

add normalization

parent 4e178a24
......@@ -784,7 +784,7 @@ template<bool bwd, typename T> NOINLINE void passg (size_t ido, size_t ip,
template<bool bwd, typename T> NOINLINE void pass_all(T c[], T0 fact)
{
if (length==1) return;
if (length==1) { c[0]*=fact; return; }
size_t l1=1;
arr<T> ch(length);
T *p1=c, *p2=ch.data();
......@@ -1574,7 +1574,7 @@ template<typename T> static void copy_and_norm(T *c, T *p1, size_t n, T0 fct)
template<typename T> void forward(T c[], T0 fact)
{
if (length==1) return;
if (length==1) { c[0]*=fact; return; }
size_t n=length;
size_t l1=n, nf=nfct;
arr<T> ch(n);
......@@ -1606,7 +1606,7 @@ template<typename T> void forward(T c[], T0 fact)
template<typename T> void backward(T c[], T0 fact)
{
if (length==1) return;
if (length==1) { c[0]*=fact; return; }
size_t n=length;
size_t l1=1, nf=nfct;
arr<T> ch(n);
......@@ -2203,7 +2203,7 @@ void check_args(const py::array &in, vector<size_t> &axes)
throw runtime_error("invalid axis number");
}
py::array execute(const py::array &in, vector<size_t> axes, bool fwd)
py::array execute(const py::array &in, vector<size_t> axes, const string &norm, bool fwd)
{
check_args(in, axes);
py::array res;
......@@ -2211,6 +2211,15 @@ py::array execute(const py::array &in, vector<size_t> axes, bool fwd)
vector<int64_t> s_i(in.ndim()), s_o(in.ndim());
for (int i=0; i<in.ndim(); ++i)
dims[i] = in.shape(i);
size_t n_elem=1;
for (size_t i=0; i<axes.size(); ++i)
n_elem *= in.shape(axes[i]);
double fct=1;
if (norm=="ortho")
fct=1./sqrt(n_elem);
else
if (!fwd)
fct = 1./n_elem;
if (in.dtype().is(c128))
res = py::array_t<complex<double>>(dims);
else if (in.dtype().is(c64))
......@@ -2229,20 +2238,20 @@ py::array execute(const py::array &in, vector<size_t> axes, bool fwd)
pocketfft_general_c<double>(in.ndim(), dims.data(),
s_i.data(), s_o.data(), axes.size(),
axes.data(), fwd, (const cmplx<double> *)in.data(),
(cmplx<double> *)res.mutable_data(), 1.);
(cmplx<double> *)res.mutable_data(), fct);
else
pocketfft_general_c<float>(in.ndim(), dims.data(),
s_i.data(), s_o.data(), axes.size(),
axes.data(), fwd, (const cmplx<float> *)in.data(),
(cmplx<float> *)res.mutable_data(), 1.);
(cmplx<float> *)res.mutable_data(), fct);
return res;
}
py::array fftn(const py::array &a, const vector<size_t> &axes)
{ return execute(a, axes, true); }
py::array ifftn(const py::array &a, const vector<size_t> &axes)
{ return execute(a, axes, false); }
py::array fftn(const py::array &a, const vector<size_t> &axes, const string &norm)
{ return execute(a, axes, norm, true); }
py::array ifftn(const py::array &a, const vector<size_t> &axes, const string &norm)
{ return execute(a, axes, norm, false); }
py::array rfftn(const py::array &in, vector<size_t> axes)
py::array rfftn(const py::array &in, vector<size_t> axes, const string &norm)
{
check_args(in, axes);
py::array res;
......@@ -2254,6 +2263,12 @@ py::array rfftn(const py::array &in, vector<size_t> axes)
dims_out[i] = in.shape(i);
}
dims_out[axes.back()] = (dims_out[axes.back()]>>1)+1;
size_t n_elem=1;
for (size_t i=0; i<axes.size(); ++i)
n_elem *= in.shape(axes[i]);
double fct=1;
if (norm=="ortho")
fct=1./sqrt(n_elem);
if (in.dtype().is(f64))
res = py::array_t<complex<double>>(dims_out);
else if (in.dtype().is(f32))
......@@ -2273,7 +2288,7 @@ py::array rfftn(const py::array &in, vector<size_t> axes)
pocketfft_general_r2c<double>(in.ndim(), dims_in.data(),
s_i.data(), s_o.data(),
axes.back(), (const double *)in.data(),
(cmplx<double> *)res.mutable_data(), 1.);
(cmplx<double> *)res.mutable_data(), fct);
pocketfft_general_c<double>(res.ndim(), dims_out.data(),
s_o.data(), s_o.data(), axes.size()-1,
axes.data(), true, (const cmplx<double> *)res.data(),
......@@ -2284,7 +2299,7 @@ py::array rfftn(const py::array &in, vector<size_t> axes)
pocketfft_general_r2c<float>(in.ndim(), dims_in.data(),
s_i.data(), s_o.data(),
axes.back(), (const float *)in.data(),
(cmplx<float> *)res.mutable_data(), 1.);
(cmplx<float> *)res.mutable_data(), fct);
pocketfft_general_c<float>(res.ndim(), dims_out.data(),
s_o.data(), s_o.data(), axes.size()-1,
axes.data(), true, (const cmplx<float> *)res.data(),
......@@ -2292,7 +2307,7 @@ py::array rfftn(const py::array &in, vector<size_t> axes)
}
return res;
}
py::array irfftn(const py::array &in, vector<size_t> axes, size_t lastsize)
py::array irfftn(const py::array &in, vector<size_t> axes, size_t lastsize, const string &norm)
{
check_args(in, axes);
py::array inter;
......@@ -2301,7 +2316,7 @@ py::array irfftn(const py::array &in, vector<size_t> axes, size_t lastsize)
vector<size_t> axes2(axes.size()-1);
for (size_t i=0; i<axes2.size(); ++i)
axes2[i] = axes[i];
inter = ifftn(in, axes2);
inter = ifftn(in, axes2, norm);
}
else
inter = in;
......@@ -2318,6 +2333,11 @@ py::array irfftn(const py::array &in, vector<size_t> axes, size_t lastsize)
dims_out[i] = inter.shape(i);
}
dims_out[axis] = lastsize;
double fct=1;
if (norm=="ortho")
fct=1./sqrt(lastsize);
else
fct = 1./lastsize;
py::array res;
if (inter.dtype().is(c128))
res = py::array_t<double>(dims_out);
......@@ -2337,12 +2357,12 @@ py::array irfftn(const py::array &in, vector<size_t> axes, size_t lastsize)
pocketfft_general_c2r<double>(inter.ndim(), dims_out.data(),
s_i.data(), s_o.data(),
axis, (const cmplx<double> *)inter.data(),
(double *)res.mutable_data(), 1.);
(double *)res.mutable_data(), fct);
else
pocketfft_general_c2r<float>(inter.ndim(), dims_out.data(),
s_i.data(), s_o.data(),
axis, (const cmplx<float> *)inter.data(),
(float *)res.mutable_data(), 1.);
(float *)res.mutable_data(), fct);
return res;
}
py::array hartley(const py::array &in, vector<size_t> axes)
......@@ -2379,22 +2399,15 @@ py::array hartley(const py::array &in, vector<size_t> axes)
(float *)res.mutable_data(), 1.);
return res;
}
void xtest(const py::array &a, const vector<size_t> &s, const vector<size_t> &axes, const string &norm)
{
// for (size_t i=0; i<data.size(); ++i)
// cout << data[i] << endl;
}
} // unnamed namespace
PYBIND11_MODULE(pypocketfft, m)
{
using namespace pybind11::literals;
m.def("fftn",&fftn, "a"_a, "axes"_a=vector<size_t>(1,10000));
m.def("ifftn",&ifftn, "a"_a, "axes"_a=vector<size_t>(1,10000));
m.def("rfftn",&rfftn, "a"_a, "axes"_a=vector<size_t>(1,10000));
m.def("irfftn",&irfftn, "a"_a, "axes"_a=vector<size_t>(1,10000), "lastsize"_a=0);
m.def("fftn",&fftn, "a"_a, "axes"_a=vector<size_t>(1,10000), "norm"_a=string(""));
m.def("ifftn",&ifftn, "a"_a, "axes"_a=vector<size_t>(1,10000), "norm"_a=string(""));
m.def("rfftn",&rfftn, "a"_a, "axes"_a=vector<size_t>(1,10000), "norm"_a=string(""));
m.def("irfftn",&irfftn, "a"_a, "axes"_a=vector<size_t>(1,10000), "lastsize"_a=0, "norm"_a=string(""));
m.def("hartley",&hartley, "a"_a, "axes"_a=vector<size_t>(1,10000));
m.def("xtest",&xtest,"a"_a,"s"_a=vector<size_t>(), "axes"_a=vector<size_t>(), "norm"_a="");
}
......@@ -16,6 +16,7 @@ def test_fftn(shp):
a=np.random.rand(*shp)-0.5 + 1j*np.random.rand(*shp)-0.5j
vmax = np.max(np.abs(pyfftw.interfaces.numpy_fft.fftn(a)))
assert_allclose(pyfftw.interfaces.numpy_fft.fftn(a), pypocketfft.fftn(a), atol=2e-15*vmax, rtol=0)
assert_allclose(pyfftw.interfaces.numpy_fft.fftn(a,norm="ortho"), pypocketfft.fftn(a,norm="ortho"), atol=2e-15*vmax, rtol=0)
a=a.astype(np.complex64)
assert_allclose(pyfftw.interfaces.numpy_fft.fftn(a), pypocketfft.fftn(a), atol=6e-7*vmax, rtol=0)
......@@ -33,6 +34,7 @@ def test_rfftn(shp):
a=np.random.rand(*shp)-0.5
vmax = np.max(np.abs(pyfftw.interfaces.numpy_fft.fftn(a)))
assert_allclose(pyfftw.interfaces.numpy_fft.rfftn(a), pypocketfft.rfftn(a), atol=2e-15*vmax, rtol=0)
assert_allclose(pyfftw.interfaces.numpy_fft.rfftn(a,norm="ortho"), pypocketfft.rfftn(a,norm="ortho"), atol=2e-15*vmax, rtol=0)
a=a.astype(np.float32)
assert_allclose(pyfftw.interfaces.numpy_fft.rfftn(a), pypocketfft.rfftn(a), atol=6e-7*vmax, rtol=0)
......@@ -49,19 +51,21 @@ def test_rfftn2D(shp, axes):
def test_identity(shp):
a=np.random.rand(*shp)-0.5 + 1j*np.random.rand(*shp)-0.5j
vmax = np.max(np.abs(a))
assert_allclose(pypocketfft.ifftn(pypocketfft.fftn(a))/a.size, a, atol=2e-15*vmax, rtol=0)
assert_allclose(pypocketfft.ifftn(pypocketfft.fftn(a)), a, atol=2e-15*vmax, rtol=0)
assert_allclose(pypocketfft.ifftn(pypocketfft.fftn(a,norm="ortho"), norm="ortho"), a, atol=2e-15*vmax, rtol=0)
a=a.astype(np.complex64)
assert_allclose(pypocketfft.ifftn(pypocketfft.fftn(a))/a.size, a, atol=6e-7*vmax, rtol=0)
assert_allclose(pypocketfft.ifftn(pypocketfft.fftn(a)), a, atol=6e-7*vmax, rtol=0)
@pmp("shp", shapes)
def xtest_identity_r(shp):
def test_identity_r(shp):
a=np.random.rand(*shp)-0.5
b=a.astype(np.float32)
vmax = np.max(np.abs(a))
for ax in range(a.ndim):
n = a.shape[ax]
assert_allclose(pypocketfft.irfftn(pypocketfft.rfftn(a,ax),ax)/n, a, atol=2e-15*vmax, rtol=0)
assert_allclose(pypocketfft.irfftn(pypocketfft.rfftn(b,ax),ax)/n, b, atol=6e-7*vmax, rtol=0)
assert_allclose(pypocketfft.irfftn(pypocketfft.rfftn(a,(ax,)),(ax,),n), a, atol=2e-15*vmax, rtol=0)
assert_allclose(pypocketfft.irfftn(pypocketfft.rfftn(a,(ax,),norm="ortho"),(ax,),n,norm="ortho"), a, atol=2e-15*vmax, rtol=0)
assert_allclose(pypocketfft.irfftn(pypocketfft.rfftn(b,(ax,)),(ax,),n), b, atol=6e-7*vmax, rtol=0)
@pmp("shp", shapes3D)
def xtest_hartley(shp):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment