# bench_gridder.py: timing benchmark of ift.GridderMaker versus ift.NFFT
from time import perf_counter, time

import matplotlib.pyplot as plt
import numpy as np

import nifty5 as ift

# Switch nifty5 to its FFTW-based FFT implementation and seed NumPy's global
# RNG, so every run draws the same random points, visibilities and images.
ift.fft.enable_fftw()
np.random.seed(40)

# Result containers, filled once per problem size N by the benchmark loop.
# Suffix 0 belongs to ift.GridderMaker, suffix 1 to ift.NFFT:
#   N*s -- number of points, a*s -- operator setup time,
#   b*s -- forward application time, c*s -- adjoint application time.
N0s, a0s, b0s, c0s = ([] for _ in range(4))
N1s, a1s, b1s, c1s = ([] for _ in range(4))

def _time_apply(op, image, visibilities):
    """Time one forward and one adjoint application of ``op``.

    The ``to_global_data()`` conversion of each result is included in the
    timed region.

    Returns
    -------
    tuple of float
        Elapsed seconds of ``op(image)`` and of ``op.adjoint(visibilities)``.
    """
    t0 = perf_counter()
    op(image).to_global_data()
    t1 = perf_counter()
    op.adjoint(visibilities).to_global_data()
    t2 = perf_counter()
    return t1 - t0, t2 - t1


# Sweep N = 2**10 ... 2**22 random uv points (coordinates in [-0.5, 0.5)) on a
# fixed 1024 x 1024 grid. For every N and both operators, record the setup
# time, the time of one forward application and of one adjoint application.
# perf_counter is used because it is the clock intended for short intervals;
# time.time() follows the (adjustable) system wall clock.
# NOTE: GM and img intentionally outlive the loop -- the rest-operator
# measurement further down uses the values from the final iteration.
nu = 1024
nv = 1024
uvspace = ift.RGSpace((nu, nv))  # independent of N, so built only once

for ii in range(10, 23):
    N = 2**ii
    print('N = {}'.format(N))

    uv = np.random.rand(N, 2) - 0.5
    vis = np.random.randn(N) + 1j*np.random.randn(N)
    visspace = ift.UnstructuredDomain(N)

    img = np.random.randn(nu*nv).reshape((nu, nv))
    img = ift.from_global_data(uvspace, img)

    # Gridder. Sorting the points with getReordering() counts as part of its
    # setup cost; uv and vis keep that order for the NFFT run as well.
    t0 = perf_counter()
    GM = ift.GridderMaker(uvspace, eps=1e-7)
    idx = GM.getReordering(uv)
    uv = uv[idx]
    vis = vis[idx]
    op = GM.getFull(uv).adjoint
    t1 = perf_counter()
    # Wrap the visibilities outside the timed setup region: NFFT's setup
    # timing below does not include this conversion either.
    vis = ift.from_global_data(visspace, vis)
    t_apply, t_adjoint = _time_apply(op, img, vis)
    N0s.append(N)
    a0s.append(t1 - t0)
    b0s.append(t_apply)
    c0s.append(t_adjoint)

    # NFFT on the same (reordered) points and visibilities.
    t0 = perf_counter()
    op = ift.NFFT(uvspace, uv)
    t1 = perf_counter()
    t_apply, t_adjoint = _time_apply(op, img, vis)
    N1s.append(N)
    a1s.append(t1 - t0)
    b1s.append(t_apply)
    c1s.append(t_adjoint)

# Time the adjoint of the "rest" operator of the GridderMaker built in the
# final loop iteration, applied to that iteration's image. The 10 samples are
# accumulated in a StatCalculator; its mean and variance feed the reference
# lines of the "Apply" plot. perf_counter is the clock intended for short
# intervals (time.time() follows the adjustable system wall clock).
print('Measure rest operator')
sc = ift.StatCalculator()
op = GM.getRest().adjoint
for _ in range(10):
    t0 = perf_counter()
    res = op(img)
    sc.add(perf_counter() - t0)
t_fft = sc.mean
print('FFT shape', res.shape)

def _finish_plot(title, filename):
    """Label the current figure, put both axes on log scale, save and close it."""
    plt.ylabel('time [s]')
    plt.title(title)
    plt.loglog()
    plt.savefig(filename)
    plt.close()


# Figure 1: operator setup time versus number of points N.
plt.scatter(N0s, a0s, label='Gridder mr')
plt.scatter(N1s, a1s, marker='^', label='NFFT')
plt.legend()
# Explicit y-limits: without them the y-range of this plot came out wrong.
# NOTE(review): presumably because the limits are autoscaled while the axes
# are still linear and loglog() is only applied afterwards -- TODO confirm.
plt.ylim(min(a0s + a1s), max(a0s + a1s))
_finish_plot('Initialization', 'bench0.png')

# Figure 2: forward and adjoint application times, with the mean of the
# rest-operator timing and mean +/- sqrt(variance) as horizontal reference lines.
plt.scatter(N0s, b0s, color='k', marker='^', label='Gridder mr times')
plt.scatter(N1s, b1s, color='r', marker='^', label='NFFT times')
plt.scatter(N0s, c0s, color='k', label='Gridder mr adjoint times')
plt.scatter(N1s, c1s, color='r', label='NFFT adjoint times')
fft_spread = np.sqrt(sc.var)
plt.axhline(sc.mean, label='FFT')
plt.axhline(sc.mean + fft_spread)
# The y-axis becomes logarithmic, so a non-positive lower line cannot be shown.
if sc.mean - fft_spread > 0:
    plt.axhline(sc.mean - fft_spread)
plt.legend()
_finish_plot('Apply', 'bench1.png')