"""bench_gridder.py: benchmark nifty6's GridderMaker (setup, apply and adjoint times)."""
from time import time

import matplotlib.pyplot as plt
import numpy as np

import nifty6 as ift

np.random.seed(40)

# One entry per benchmarked N:
#   N0s -- number of visibilities N
#   a0s -- time to build the GridderMaker and its full (adjoint) operator
#   b0s -- time to apply op to the image field
#   c0s -- time to apply op.adjoint to the visibility field
N0s, a0s, b0s, c0s = [], [], [], []

# The image grid does not depend on N, so build it once outside the loop.
nu = 1024
nv = 1024
uvspace = ift.RGSpace((nu, nv))

for ii in range(10, 26):
    N = 2**ii
    print('N = {}'.format(N))

    # N random uv coordinates in [-0.5, 0.5) and complex Gaussian visibilities.
    uv = np.random.rand(N, 2) - 0.5
    vis = np.random.randn(N) + 1j*np.random.randn(N)

    visspace = ift.UnstructuredDomain(N)

    img = np.random.randn(nu*nv)
    img = img.reshape((nu, nv))
    img = ift.makeField(uvspace, img)
    # Wrap the visibilities before starting the clock so that only the
    # gridder construction is attributed to "Initialization".
    vis = ift.makeField(visspace, vis)

    t0 = time()
    GM = ift.GridderMaker(uvspace, eps=1e-7, uv=uv)
    op = GM.getFull().adjoint
    t1 = time()
    op(img).val
    t2 = time()
    op.adjoint(vis).val
    t3 = time()
    print(t2-t1, t3-t2)

    N0s.append(N)
    a0s.append(t1 - t0)
    b0s.append(t2 - t1)
    c0s.append(t3 - t2)

# Time the adjoint of GM.getRest() on the image field, repeated 10 times.
# NOTE(review): GM and img are whatever the benchmark loop left behind, i.e.
# the ones built for the largest N.
print('Measure rest operator')
sc = ift.StatCalculator()
op = GM.getRest().adjoint
for _repeat in range(10):
    tic = time()
    res = op(img)
    elapsed = time() - tic
    sc.add(elapsed)
t_fft = sc.mean
print('FFT shape', res.shape)

# Figure 1 (bench0.png): GridderMaker construction time versus N.
plt.scatter(N0s, a0s, label='Gridder mr')
plt.legend()
# Pin the y-range to the data explicitly. Without this, the autoscaled
# range comes out wrong for scatter() once the axes are switched to log
# scale; presumably the autoscale margins are computed in linear space
# (TODO confirm). A multiplicative pad keeps the extreme markers from being
# clipped by the axes frame on the log axis.
plt.ylim(min(a0s) / 1.2, max(a0s) * 1.2)
plt.xlabel('number of visibilities N')
plt.ylabel('time [s]')
plt.title('Initialization')
plt.loglog()
plt.savefig('bench0.png')
plt.close()

# Figure 2 (bench1.png): time to apply op and op.adjoint versus N. The
# rest-operator timing mean is a solid reference line, and mean +/- one
# standard deviation are dashed lines.
plt.scatter(N0s, b0s, color='k', marker='^', label='Gridder mr times')
plt.scatter(N0s, c0s, color='k', label='Gridder mr adjoint times')
fft_std = np.sqrt(sc.var)
plt.axhline(sc.mean, label='FFT')
plt.axhline(sc.mean + fft_std, linestyle='--')
# A non-positive value cannot be shown on a log axis, so skip the lower
# line when the spread exceeds the mean.
fft_low = sc.mean - fft_std
if fft_low > 0:
    plt.axhline(fft_low, linestyle='--')
plt.legend()
plt.xlabel('number of visibilities N')
plt.ylabel('time [s]')
plt.title('Apply')
plt.loglog()
plt.savefig('bench1.png')
plt.close()