bench.py 3.08 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import numpy as np
import pypocketfft
from time import time
import matplotlib.pyplot as plt
import math

nthr = 1


def _l2error(a, b):
    return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2))


def prime_factorize(n):
    """Return the prime factors of the integer ``n`` as a tuple.

    Factors are listed in non-decreasing order with multiplicity, e.g.
    ``prime_factorize(12) == (2, 2, 3)``. For ``n < -1`` the sign is carried
    on the first factor (``prime_factorize(-12) == (-2, 2, 3)``); 0, 1 and -1
    have no factors and give ``()``.

    Integral floats such as ``12.0`` are accepted. The work is done in exact
    integer arithmetic, so values beyond the float mantissa (> 2**53) are
    factored correctly.

    Raises:
        ValueError: if ``n`` is not integral (such input used to loop forever).
    """
    remaining = abs(int(n))
    if remaining != abs(n):
        raise ValueError(
            "prime_factorize expects an integer, got {!r}".format(n))

    factors = []
    divisor = 2
    # Single trial-division pass: each divisor is exhausted before moving on,
    # so only primes ever divide `remaining` and we never restart from 2.
    while divisor * divisor <= remaining:
        while remaining % divisor == 0:
            factors.append(divisor)
            remaining //= divisor
        divisor += 1 if divisor == 2 else 2
    if remaining > 1:
        # Whatever is left has no divisor <= its square root: it is prime.
        factors.append(remaining)

    if n < -1:  # If we'd check for < 0, -1 would give us trouble
        factors[0] = -factors[0]

    return tuple(factors)


def get_next_prime_factor(n):
    """Return the smallest prime factor of ``n`` (expects ``n >= 2``).

    Trial division by 2 and then by odd candidates ``x`` with ``x*x <= n``.
    Odd composite candidates are tested too, but they can never be the first
    divisor found: their own prime factors are smaller and would already have
    matched, so no prime sieve is needed. If nothing divides, ``n`` is prime
    and is returned as an int.

    The bound is checked as ``x * x <= n`` instead of via ``math.sqrt`` so
    that large integers are handled exactly. Otherwise they would be rounded
    to float, or would overflow it with an OverflowError.
    """
    if n % 2 == 0:
        return 2

    candidate = 3
    while candidate * candidate <= n:
        if n % candidate == 0:
            return candidate
        candidate += 2
    return int(n)


def measure_fftw(a, nrepeat):
    """Return the best run time (seconds) of an FFTW plan for ``a``.

    A plan over *all* axes of ``a`` is created once with ``FFTW_MEASURE``
    and ``nthr`` threads; planning time is excluded. The plan is then
    executed ``nrepeat`` times and the minimum elapsed time is returned.

    Raises:
        ValueError: if ``nrepeat < 1``.
    """
    import pyfftw
    from time import perf_counter

    if nrepeat < 1:
        raise ValueError("nrepeat must be at least 1")
    f1 = pyfftw.empty_aligned(a.shape, dtype=a.dtype)
    f2 = pyfftw.empty_aligned(a.shape, dtype=a.dtype)
    # pyfftw.FFTW transforms only the last axis by default (axes=(-1,));
    # request every axis so this matches the n-D transforms timed elsewhere.
    fftw = pyfftw.FFTW(f1, f2, axes=tuple(range(a.ndim)),
                       flags=('FFTW_MEASURE',), threads=nthr)
    # The FFTW planner may overwrite the input buffer for any flag other than
    # FFTW_ESTIMATE, so the data is copied in only after planning.
    f1[()] = a
    tmin = float("inf")
    for _ in range(nrepeat):
        t0 = perf_counter()
        fftw()
        tmin = min(tmin, perf_counter() - t0)
    return tmin


def measure_fftw_np_interface(a, nrepeat):
    """Return the best of ``nrepeat`` timings (seconds) of
    ``pyfftw.interfaces.numpy_fft.fftn(a)``.

    The pyfftw interface cache is enabled so that later calls can reuse the
    plan. The first call may therefore include planning; taking the minimum
    keeps that from dominating when ``nrepeat >= 2``. ``nthr`` threads are
    requested, like the other threaded measurements.

    Raises:
        ValueError: if ``nrepeat < 1``.
    """
    import pyfftw
    from time import perf_counter

    if nrepeat < 1:
        raise ValueError("nrepeat must be at least 1")
    pyfftw.interfaces.cache.enable()
    tmin = float("inf")
    for _ in range(nrepeat):
        t0 = perf_counter()
        pyfftw.interfaces.numpy_fft.fftn(a, threads=nthr)
        tmin = min(tmin, perf_counter() - t0)
    return tmin


def measure_pypocketfft(a, nrepeat):
    """Return the best of ``nrepeat`` timings (seconds) of a forward
    ``pypocketfft.c2c`` transform of ``a`` using ``nthr`` threads.

    ``c2c`` is called without ``axes``; presumably that transforms all
    axes (confirm against the pypocketfft docs).

    Raises:
        ValueError: if ``nrepeat < 1``.
    """
    import pypocketfft as ppf
    from time import perf_counter

    if nrepeat < 1:
        raise ValueError("nrepeat must be at least 1")
    tmin = float("inf")
    for _ in range(nrepeat):
        # perf_counter is monotonic and high-resolution; time() is wall-clock
        # time that can jump and is coarse on some platforms.
        t0 = perf_counter()
        ppf.c2c(a, forward=True, nthreads=nthr)
        tmin = min(tmin, perf_counter() - t0)
    return tmin


def measure_scipy_fftpack(a, nrepeat):
    """Return the best of ``nrepeat`` timings (seconds) of
    ``scipy.fftpack.fftn(a)``.

    Raises:
        ValueError: if ``nrepeat < 1``.
    """
    import scipy.fftpack
    from time import perf_counter

    if nrepeat < 1:
        raise ValueError("nrepeat must be at least 1")
    tmin = float("inf")
    for _ in range(nrepeat):
        # perf_counter is monotonic and high-resolution; time() is wall-clock
        # time that can jump and is coarse on some platforms.
        t0 = perf_counter()
        scipy.fftpack.fftn(a)
        tmin = min(tmin, perf_counter() - t0)
    return tmin


def bench_nd(ndim, nmax, ntry, tp, funcs, nrepeat, ttl="", filename=""):
    """Time ``funcs`` on random complex arrays and plot a histogram of the
    ratio ``time(funcs[0]) / time(funcs[1])``.

    Parameters:
        ndim: number of array dimensions.
        nmax: maximum extent per axis; each extent is drawn from
            ``np.random.randint(nmax//3, nmax+1)``.
        ntry: number of random test arrays (a new shape for each).
        tp: dtype the complex test data is cast to (e.g. "c16", "c8").
        funcs: sequence of timing functions ``f(a, nrepeat) -> seconds``.
            All are run, but only the first two enter the histogram.
        nrepeat: passed through to every timing function.
        ttl: prefix of the plot title.
        filename: if non-empty, the figure is saved there before plt.show().
    """
    results = [[] for _ in funcs]
    for n in range(ntry):
        print(n, ntry)  # progress: current try / total tries
        shp = np.random.randint(nmax//3, nmax+1, ndim)
        a = (np.random.rand(*shp) + 1j*np.random.rand(*shp)).astype(tp)
        for func, res in zip(funcs, results):
            res.append(func(a, nrepeat))
    results = np.array(results)
    # Start a fresh figure. Without it, whenever plt.show() does not close the
    # figure (e.g. non-interactive backends), successive calls draw onto the
    # same axes and later PNGs also contain every earlier histogram.
    plt.figure()
    plt.title("{}: {}D, {}, max_extent={}".format(
        ttl, ndim, str(tp), nmax))
    plt.xlabel("time ratio")
    plt.ylabel("counts")
    plt.hist(results[0, :]/results[1, :], bins="auto")
    if filename:
        plt.savefig(filename)
    plt.show()


# Timing functions compared by bench_nd; each histogram shows the ratio
# time(funcs[0]) / time(funcs[1]).
funcs = (measure_pypocketfft, measure_fftw)
ttl = "pypocketfft/FFTW()"

# Only run the (long) benchmarks when executed as a script, so the measure_*
# helpers can be imported without side effects.
if __name__ == "__main__":
    bench_nd(1, 8192, 100, "c16", funcs, 10, ttl, "1d.png")
    bench_nd(2, 2048, 100, "c16", funcs, 2, ttl, "2d.png")
    bench_nd(3, 256, 100, "c16", funcs, 2, ttl, "3d.png")
    bench_nd(1, 8192, 100, "c8", funcs, 10, ttl, "1d_single.png")
    bench_nd(2, 2048, 100, "c8", funcs, 2, ttl, "2d_single.png")
    bench_nd(3, 256, 100, "c8", funcs, 2, ttl, "3d_single.png")