From 6a6efcd9a6b8b824c70130a6a8dbbddac7c73e94 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Mon, 23 Apr 2018 11:57:08 +0200 Subject: [PATCH] cosmetics --- demos/krylov_sampling.py | 12 +++++++++--- nifty4/library/krylov_sampling.py | 20 +++++++++----------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/demos/krylov_sampling.py b/demos/krylov_sampling.py index 65ca19dfe..b5cb5b62e 100644 --- a/demos/krylov_sampling.py +++ b/demos/krylov_sampling.py @@ -34,6 +34,7 @@ D_inv = R_p.adjoint * N.inverse * R_p + S.inverse history = [] + def sample(D_inv, S, j, N_samps, N_iter): global history space = D_inv.domain @@ -87,8 +88,12 @@ plt.close() pltdict = {'alpha': .3, 'linewidth': .2} for i in range(N_samps): if i == 0: - plt.plot(sky(samps_old[i]).val, color='b', **pltdict, label='Traditional samples (residuals)') - plt.plot(sky(samps[i]).val, color='r', **pltdict, label='Krylov samples (residuals)') + plt.plot(sky(samps_old[i]).val, color='b', + label='Traditional samples (residuals)', + **pltdict) + plt.plot(sky(samps[i]).val, color='r', + label='Krylov samples (residuals)', + **pltdict) else: plt.plot(sky(samps_old[i]).val, color='b', **pltdict) plt.plot(sky(samps[i]).val, color='r', **pltdict) @@ -105,7 +110,8 @@ for i in range(N_samps): plt.plot(np.sqrt(D_hat_old / N_samps), 'r--', label='Traditional uncertainty') plt.plot(-np.sqrt(D_hat_old / N_samps), 'r--') plt.fill_between(range(len(D_hat_new)), -np.sqrt(D_hat_new / N_samps), np.sqrt( - D_hat_new / N_samps), facecolor='0.5', alpha=0.5, label='Krylov unvertainty') + D_hat_new / N_samps), facecolor='0.5', alpha=0.5, + label='Krylov uncertainty') plt.plot((s_x - m_x).val, color='k', label='signal - mean') plt.legend() plt.savefig('Krylov_uncertainty.png') diff --git a/nifty4/library/krylov_sampling.py b/nifty4/library/krylov_sampling.py index a0a7b3545..ced72e8ec 100644 --- a/nifty4/library/krylov_sampling.py +++ b/nifty4/library/krylov_sampling.py @@ -20,9 +20,10 @@ from numpy import sqrt from numpy.random import randn -def generate_krylov_samples(D_inv, S, j=None, N_samps=1, N_iter=10, name=None): +def generate_krylov_samples(D_inv, S, j=None, N_samps=1, N_iter=10, + name=None): """ - Generates inverse samples from a curvature D + Generates inverse samples from a curvature D. This algorithm iteratively generates samples from a curvature D by applying conjugate gradient steps and resampling the curvature in search direction. @@ -52,25 +53,22 @@ def generate_krylov_samples(D_inv, S, j=None, N_samps=1, N_iter=10, name=None): D_inv(x) = j and the second entry are a list of samples from D_inv.inverse """ - if j is None: - j = S.draw_sample(from_inverse=True) + j = S.draw_sample(from_inverse=True) if j is None else j x = S.draw_sample() r = j.copy() p = r.copy() d = p.vdot(D_inv(p)) - y = [] - for i in range(N_samps): - y += [S.draw_sample()] - for k in range(1, 1 + N_iter): - gamma = r.vdot(r) / d + y = [S.draw_sample() for _ in range(N_samps)] + for k in range(1, 1+N_iter): + gamma = r.vdot(r)/d if gamma == 0.: break - x += gamma * p + x += gamma*p for i in range(N_samps): y[i] -= p.vdot(D_inv(y[i])) * p / d y[i] += randn() / sqrt(d) * p r_new = r - gamma * D_inv(p) - beta = r_new.vdot(r_new) / (r.vdot(r)) + beta = r_new.vdot(r_new) / r.vdot(r) r = r_new p = r + beta * p d = p.vdot(D_inv(p)) -- GitLab