# bench_gridder.py -- timing benchmark for the nifty5 gridder operators.
from time import perf_counter, time

import matplotlib.pyplot as plt
import numpy as np

import nifty5 as ift

# Global set-up: switch NIFTy's FFT backend to FFTW and fix the NumPy seed
# so that every run of the benchmark draws the same random inputs.
ift.fft.enable_fftw()
np.random.seed(40)

# Result accumulators, one entry per benchmarked problem size:
N0s = []  # problem size N
a0s = []  # first timed interval (t1 - t0)
b0s = []  # second timed interval (t2 - t1)
c0s = []  # third timed interval (t3 - t2)

# Benchmark loop: for N = 2**10 ... 2**22 random visibilities on a fixed
# 1024 x 1024 grid, time (a) the construction of the gridder operator,
# (b) one application of `op` and (c) one application of `op.adjoint`.
# The timings are appended to N0s / a0s / b0s / c0s for the plots below.
# Grid size and image-space domain do not depend on N, so they are created
# once, outside the loop.
nu = 1024
nv = 1024
uvspace = ift.RGSpace((nu, nv))

for ii in range(10, 23):
    N = 2**ii
    print('N = {}'.format(N))

    # Random inputs.  The order of the np.random calls (uv, vis, img) is
    # kept fixed so that a given seed always reproduces the same data.
    uv = np.random.rand(N, 2) - 0.5
    vis = np.random.randn(N) + 1j*np.random.randn(N)
    visspace = ift.UnstructuredDomain(N)
    img = np.random.randn(nu*nv).reshape((nu, nv))
    img = ift.from_global_data(uvspace, img)

    # Interval a: build the gridder, reorder uv/vis with the index array
    # returned by getReordering, and construct the full operator.
    # perf_counter() is used because it is a monotonic, high-resolution
    # clock intended for measuring short durations.
    t0 = perf_counter()
    GM = ift.GridderMaker(uvspace, eps=1e-7)
    idx = GM.getReordering(uv)
    uv = uv[idx]
    vis = vis[idx]
    vis = ift.from_global_data(visspace, vis)
    op = GM.getFull(uv).adjoint
    t1 = perf_counter()

    # Interval b: apply the operator to the image.
    op(img).to_global_data()
    t2 = perf_counter()

    # Interval c: apply the adjoint operator to the visibilities.
    op.adjoint(vis).to_global_data()
    t3 = perf_counter()

    N0s.append(N)
    a0s.append(t1 - t0)
    b0s.append(t2 - t1)
    c0s.append(t3 - t2)

# Time the gridder's "rest" operator on its own, averaged over 10 runs.
# This deliberately reuses `GM` and `img` left over from the last iteration
# of the benchmark loop above.
# NOTE(review): presumably getRest() is the FFT part of the gridder (the
# result is reported and plotted as 'FFT') -- confirm against nifty5.
print('Measure rest operator')
sc = ift.StatCalculator()
op = GM.getRest().adjoint
for _ in range(10):
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring short durations.
    t0 = perf_counter()
    res = op(img)
    sc.add(perf_counter() - t0)
t_fft = sc.mean
print('FFT shape', res.shape)

# Figure 'bench0.png': first timed interval (a0s, operator set-up) versus
# problem size N, on log-log axes.
plt.scatter(N0s, a0s, label='Gridder mr')
plt.legend()
# no idea why this is necessary, but if it is omitted, the range is wrong
# NOTE(review): possibly because the axes are switched to log scale only
# after scatter() has already been drawn -- TODO confirm.
plt.ylim(min(a0s), max(a0s))
plt.ylabel('time [s]')
plt.title('Initialization')
plt.loglog()
plt.savefig('bench0.png')
plt.close()

# Figure 'bench1.png': time of one operator application (triangles) and of
# one adjoint application (circles) versus N, on log-log axes.
plt.scatter(N0s, b0s, color='k', marker='^', label='Gridder mr times')
plt.scatter(N0s, c0s, color='k', label='Gridder mr adjoint times')

# Horizontal reference lines: mean time of the separately measured operator
# (labelled 'FFT') and the band of one standard deviation around it.
fft_mean = sc.mean
fft_upper = sc.mean + np.sqrt(sc.var)
fft_lower = sc.mean - np.sqrt(sc.var)
plt.axhline(fft_mean, label='FFT')
plt.axhline(fft_upper)
plt.axhline(fft_lower)

plt.legend()
plt.ylabel('time [s]')
plt.title('Apply')
plt.loglog()
plt.savefig('bench1.png')
plt.close()