bench.py 3.12 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import numpy as np
import pypocketfft
import pyfftw
from time import time
import matplotlib.pyplot as plt
import math

nthr = 1


def _l2error(a, b):
    return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2))


def prime_factorize(n):
    """Return the prime factors of the integral value `n` as a tuple.

    Factors are listed in non-decreasing order and repeated according to
    their multiplicity, e.g. 12 -> (2, 2, 3). For n < -1 the first factor
    carries the sign, e.g. -12 -> (-2, 2, 3). 0, 1 and -1 yield ().

    Raises
    ------
    ValueError
        If `n` is not integral (e.g. 12.5); such inputs used to loop forever.
    """
    number = int(n)
    if number != n:
        raise ValueError(
            "prime_factorize needs an integral value, got {!r}".format(n))
    # Work with exact Python ints: float arithmetic (math.fabs, /=) silently
    # loses precision above 2**53 and then produces wrong factors.
    number = abs(number)
    factors = []
    while number > 1:
        factor = get_next_prime_factor(number)
        factors.append(factor)
        number //= factor

    if n < -1:  # If we'd check for < 0, -1 would give us trouble (no factors)
        factors[0] = -factors[0]

    return tuple(factors)


def get_next_prime_factor(n):
    """Return the smallest prime factor of `n` (intended for integers n >= 2).

    Checks 2 first, then odd trial divisors x with x*x <= n. If none divides
    `n`, then `n` itself is prime and is returned as an int. Odd composite
    trial divisors are tried too; that is harmless, because a composite can
    never divide `n` before one of its own smaller prime factors would have.
    """
    if n % 2 == 0:
        return 2

    # Integer bound instead of math.sqrt(n): exact for arbitrarily large
    # ints, where a float sqrt loses precision or raises OverflowError.
    x = 3
    while x * x <= n:
        if n % x == 0:
            return x
        x += 2
    return int(n)


def measure_fftw(a, nrepeat):
    """Return the best execution time (seconds) of a pre-planned FFTW c2c.

    A pyfftw.FFTW plan for a forward transform over *all* axes of `a` is
    built once with FFTW_MEASURE. Planning time is not measured. The plan is
    then executed `nrepeat` times and the minimum duration is returned
    (1e38 if `nrepeat` is 0). Uses the module-level thread count `nthr`.

    Parameters
    ----------
    a : numpy.ndarray
        Input data; its shape and dtype define the transform.
    nrepeat : int
        Number of timed executions.
    """
    import pyfftw
    from time import perf_counter
    f1 = pyfftw.empty_aligned(a.shape, dtype=a.dtype)
    f2 = pyfftw.empty_aligned(a.shape, dtype=a.dtype)
    # pyfftw.FFTW transforms only the last axis by default; request all axes
    # so this matches pypocketfft.c2c and fftn.
    fftw = pyfftw.FFTW(f1, f2, axes=tuple(range(a.ndim)),
                       flags=('FFTW_MEASURE',), threads=nthr)
    # FFTW_MEASURE planning may overwrite the input buffer, so fill it only
    # after the plan has been created.
    f1[()] = a
    tmin = 1e38
    for _ in range(nrepeat):
        t0 = perf_counter()
        fftw()
        tmin = min(tmin, perf_counter() - t0)
    return tmin


def measure_fftw_np_interface(a, nrepeat):
    """Return the best time (seconds) of pyfftw's numpy-style `fftn` on `a`.

    The pyfftw interfaces cache is enabled so that repeated calls can reuse
    previously built FFTW objects. The first repetition may therefore
    include planning, which min() discards when nrepeat >= 2. Returns 1e38
    if `nrepeat` is 0.

    Parameters
    ----------
    a : numpy.ndarray
        Input data, transformed over all axes.
    nrepeat : int
        Number of timed calls.
    """
    import pyfftw
    from time import perf_counter
    pyfftw.interfaces.cache.enable()
    tmin = 1e38
    for _ in range(nrepeat):
        t0 = perf_counter()
        # Use the same thread count as the other measurements; otherwise
        # the interface falls back to pyfftw's configured default.
        res = pyfftw.interfaces.numpy_fft.fftn(a, threads=nthr)
        t1 = perf_counter()
        del res  # release the result outside the timed region
        tmin = min(tmin, t1 - t0)
    return tmin


def measure_pypocketfft(a, nrepeat):
    """Return the best time (seconds) of a forward pypocketfft c2c on `a`.

    The transform is run `nrepeat` times with the module-level thread count
    `nthr`, and the minimum duration is returned (1e38 if `nrepeat` is 0).

    Parameters
    ----------
    a : numpy.ndarray
        Complex input data.
    nrepeat : int
        Number of timed calls.
    """
    import pypocketfft as ppf
    # perf_counter is monotonic with the highest available resolution;
    # time() can be too coarse for short transforms and report 0.
    from time import perf_counter
    tmin = 1e38
    for _ in range(nrepeat):
        t0 = perf_counter()
        res = ppf.c2c(a, forward=True, nthreads=nthr)
        t1 = perf_counter()
        del res  # release the result outside the timed region
        tmin = min(tmin, t1 - t0)
    return tmin


def measure_scipy_fftpack(a, nrepeat):
    """Return the best time (seconds) of `scipy.fftpack.fftn` on `a`.

    The transform is run `nrepeat` times and the minimum duration is
    returned (1e38 if `nrepeat` is 0).

    Parameters
    ----------
    a : numpy.ndarray
        Input data, transformed over all axes.
    nrepeat : int
        Number of timed calls.
    """
    import scipy.fftpack
    # perf_counter is monotonic with the highest available resolution;
    # time() can be too coarse for short transforms and report 0.
    from time import perf_counter
    tmin = 1e38
    for _ in range(nrepeat):
        t0 = perf_counter()
        res = scipy.fftpack.fftn(a)
        t1 = perf_counter()
        del res  # release the result outside the timed region
        tmin = min(tmin, t1 - t0)
    return tmin


def bench_nd(ndim, nmax, ntry, tp, funcs, nrepeat, ttl="", filename=""):
    """Time FFT implementations on random arrays and plot a ratio histogram.

    For each of `ntry` trials, a shape of `ndim` extents is drawn with
    np.random.randint(nmax//3, nmax+1, ndim), and random complex data is
    cast to dtype `tp`. Each callable in `funcs` is called as
    ``func(a, nrepeat)`` and its returned time is recorded. A histogram of
    time(funcs[0]) / time(funcs[1]) is then drawn, so `funcs` must contain
    at least two callables.

    Parameters
    ----------
    ndim : int
        Number of array dimensions.
    nmax : int
        Maximum extent per dimension.
    ntry : int
        Number of random shapes to test.
    tp : str or dtype
        Array dtype, e.g. "c16" or "c8".
    funcs : sequence of callables
        Timing functions; only the first two enter the histogram.
    nrepeat : int
        Repetitions passed to each timing function.
    ttl : str
        Prefix for the plot title.
    filename : str
        If non-empty, the figure is saved to this path before showing.
    """
    results = [[] for _ in funcs]
    for n in range(ntry):
        print(n, ntry)
        shp = np.random.randint(nmax//3, nmax+1, ndim)
        a = (np.random.rand(*shp) + 1j*np.random.rand(*shp)).astype(tp)
        for func, res in zip(funcs, results):
            res.append(func(a, nrepeat))
    results = np.array(results)
    # Start a fresh figure: with non-interactive backends plt.show() does not
    # close the current figure, so consecutive calls would otherwise overlay
    # their histograms and save them into the same image.
    plt.figure()
    plt.title("{}: {}D, {}, max_extent={}".format(
        ttl, ndim, str(tp), nmax))
    plt.xlabel("time ratio")
    plt.ylabel("counts")
    plt.hist(results[0, :]/results[1, :], bins="auto")
    if filename:
        plt.savefig(filename)
    plt.show()


# Benchmark driver: compare pypocketfft against pyfftw's numpy interface for
# 1-3 dimensional transforms, first in double and then in single precision.
funcs = (measure_pypocketfft, measure_fftw_np_interface)
ttl = "pypocketfft/fftw_numpy_interface"
# (ndim, maximum extent, repetitions per trial)
_configs = ((1, 8192, 10), (2, 2048, 2), (3, 256, 2))
for _dtype, _suffix in (("c16", ""), ("c8", "_single")):
    for _ndim, _nmax, _nrep in _configs:
        bench_nd(_ndim, _nmax, 100, _dtype, funcs, _nrep, ttl,
                 "{}d{}.png".format(_ndim, _suffix))