bench_nd.py 1.53 KB
Newer Older
1
2
3
4
5
6
7
# Don't run this benchmark with numpy<1.17 ... it will probably take ages!

import numpy as np
import pypocketfft
from time import time
import matplotlib.pyplot as plt

# Thread count passed to pypocketfft in the benchmark below.
nthreads = 1


def _l2error(a,b):
    return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2))
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25

def _best_time(func, arr, nrepeat):
    """Call ``func(arr)`` ``nrepeat`` times; return (fastest time in s, last result).

    Uses ``time.perf_counter``: ``time.time`` may tick in coarse steps on
    some platforms, which for tiny transforms could yield a zero minimum
    and hence a division by zero in the caller's time ratio.
    """
    from time import perf_counter

    best = float("inf")
    out = None
    for _ in range(nrepeat):
        start = perf_counter()
        out = func(arr)
        best = min(best, perf_counter() - start)
    return best, out


def bench_nd_fftn(ndim, nmax, ntry, tp, nrepeat, filename=""):
    """Benchmark pypocketfft.c2c against numpy.fft.fftn on random shapes.

    For each of ``ntry`` trials, a shape with ``ndim`` axes (each extent
    drawn from [1, nmax]) is filled with random complex data of dtype
    ``tp``. Both libraries transform it ``nrepeat`` times. The fastest
    run of each is kept, and the ratio t(pypocketfft) / t(numpy) is
    recorded. The pypocketfft output is round-tripped through the
    inverse transform to check accuracy. The collected ratios are shown
    as a histogram.

    Parameters
    ----------
    ndim : int
        Number of array dimensions.
    nmax : int
        Maximum extent along each axis.
    ntry : int
        Number of random shapes to benchmark.
    tp : str or numpy dtype
        Complex dtype of the test data, e.g. "c16" or "c8".
    nrepeat : int
        Timed repetitions per shape and library (the minimum is used);
        must be at least 1.
    filename : str, optional
        If non-empty, the histogram is also saved to this path.

    Raises
    ------
    ValueError
        If ``nrepeat`` is smaller than 1.
    AssertionError
        If the forward/inverse round-trip error exceeds the tolerance.
    """
    if nrepeat < 1:
        raise ValueError("nrepeat must be at least 1")

    # Double-precision data permits a much tighter round-trip tolerance.
    tol = 2.5e-15 if np.dtype(tp) == np.complex128 else 6e-7
    # NOTE(review): numpy.fft before 2.0 computes in double precision even
    # for "c8" input, so single-precision ratios may compare unlike
    # precisions -- confirm for the numpy version in use.

    def pp_forward(x):
        return pypocketfft.c2c(x, nthreads=nthreads, forward=True)

    res = []
    for _ in range(ntry):
        shp = np.random.randint(1, nmax+1, ndim)
        a = (np.random.rand(*shp) + 1j*np.random.rand(*shp)).astype(tp)
        tmin_np, _ = _best_time(np.fft.fftn, a, nrepeat)
        tmin_pp, b = _best_time(pp_forward, a, nrepeat)
        # inorm=2 normalizes the inverse so it should reproduce `a`.
        a2 = pypocketfft.c2c(b, inorm=2, forward=False, nthreads=nthreads)
        err = _l2error(a, a2)
        # Explicit raise (not `assert`) so the check survives `python -O`;
        # `not err < tol` also rejects NaN.
        if not err < tol:
            raise AssertionError(
                "round-trip error {} exceeds tolerance {} for shape {}"
                .format(err, tol, tuple(shp)))
        res.append(tmin_pp/tmin_np)

    # A fresh figure per call, so consecutive calls don't draw onto the
    # same axes when plt.show() is a no-op (non-interactive backends).
    plt.figure()
    plt.title("t(pypocketfft / numpy 1.17), {}D, {}, max_extent={}".format(ndim, str(tp), nmax))
    plt.xlabel("time ratio")
    plt.ylabel("counts")
    plt.hist(res, bins="auto")
    if filename:
        plt.savefig(filename)
    plt.show()

if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse bench_nd_fftn)
    # does not start the long-running benchmark.
    # Arguments: ndim, max extent, number of shapes, dtype, repeats, output file.
    # Double precision first, then single precision; the call order is kept
    # so the global numpy random stream yields the same shapes as before.
    bench_nd_fftn(1, 8192, 1000, "c16", 10, "1d.png")
    bench_nd_fftn(2, 2048, 100, "c16", 2, "2d.png")
    bench_nd_fftn(3, 256, 100, "c16", 1, "3d.png")
    bench_nd_fftn(1, 8192, 1000, "c8", 10, "1d_single.png")
    bench_nd_fftn(2, 2048, 100, "c8", 2, "2d_single.png")
    bench_nd_fftn(3, 256, 100, "c8", 1, "3d_single.png")