There is a maintenance of MPCDF Gitlab on Thursday, April 22nd 2020, 9:00 am CEST - Expect some service interruptions during this time

bench_gridder.py 1.81 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11
from time import time

import matplotlib.pyplot as plt
import numpy as np

import nifty5 as ift

# Fix the seed so the random inputs are the same on every run.
np.random.seed(40)

# Results collected by the benchmark loop, one entry per problem size:
# N0s holds N, a0s holds t1 - t0, b0s holds t2 - t1 and c0s holds t3 - t2.
N0s = []
a0s = []
b0s = []
c0s = []

12 13 14
# Benchmark loop: for N = 2**10 ... 2**25 random points on a fixed
# 2048 x 2048 grid, time the construction of the gridder and one application
# of the operator and of its adjoint.
for ii in range(10, 26):
    nu = 2048
    nv = 2048
    N = int(2**ii)
    print('N = {}'.format(N))

    # Random inputs (generated before t0, so they are not part of the timing):
    # coordinates uniform in [-0.5, 0.5) and complex Gaussian `vis` values.
    uvw = np.random.rand(N, 3) - 0.5
    vis = np.random.randn(N) + 1j*np.random.randn(N)

    uvspace = ift.RGSpace((nu, nv))

    visspace = ift.UnstructuredDomain(N)

    img = np.random.randn(nu*nv)
    img = img.reshape((nu, nv))
    img = ift.from_global_data(uvspace, img)

    # t0 -> t1: gridder construction, wrapping of `vis` into a field and
    # creation of the full operator.
    t0 = time()
    # The builtin `bool` is used as dtype: the `np.bool` alias was deprecated
    # in NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    GM = ift.GridderMaker(uvspace, eps=1e-7, uvw=uvw,
                          channel_fact=np.array([1.]),
                          flags=np.zeros((N, 1), dtype=bool))
    vis = ift.from_global_data(visspace, vis)
    op = GM.getFull().adjoint
    # t1 -> t2: apply `op` to the image; t2 -> t3: apply its adjoint to `vis`.
    t1 = time()
    op(img).to_global_data()
    t2 = time()
    op.adjoint(vis).to_global_data()
    t3 = time()
    print(t2-t1, t3-t2)

    # Record the results for the plots produced after the loop.
    N0s.append(N)
    a0s.append(t1 - t0)
    b0s.append(t2 - t1)
    c0s.append(t3 - t2)

# Time the "rest" operator of the last gridder built by the benchmark loop.
# `GM` and `img` are the objects left over from the loop's final iteration.
print('Measure rest operator')
sc = ift.StatCalculator()
op = GM.getRest().adjoint
n_repeats = 10
for _ in range(n_repeats):
    tic = time()
    res = op(img)
    toc = time()
    sc.add(toc - tic)
# Mean duration over the repeated applications.
t_fft = sc.mean
print('FFT shape', res.shape)

# Figure 1: set-up time (t1 - t0 of the benchmark loop) against N, written
# to 'bench0.png'.
plt.scatter(N0s, a0s, label='Gridder mr')
plt.legend()
# The y-range is set explicitly to the span of the measured values; the
# original author notes the range comes out wrong if this is omitted.
y_low, y_high = min(a0s), max(a0s)
plt.ylim(y_low, y_high)
plt.ylabel('time [s]')
plt.title('Initialization')
plt.loglog()
plt.savefig('bench0.png')
plt.close()

# Figure 2: time to apply the operator (triangles) and its adjoint (circles)
# against N, written to 'bench1.png'. Horizontal lines mark the mean of the
# rest-operator timings and one standard deviation above and below it.
plt.scatter(N0s, b0s, color='k', marker='^', label='Gridder mr times')
plt.scatter(N0s, c0s, color='k', label='Gridder mr adjoint times')
fft_mean = sc.mean
fft_std = np.sqrt(sc.var)
plt.axhline(fft_mean, label='FFT')
for level in (fft_mean + fft_std, fft_mean - fft_std):
    plt.axhline(level)
plt.legend()
plt.ylabel('time [s]')
plt.title('Apply')
plt.loglog()
plt.savefig('bench1.png')
plt.close()