From 6a6efcd9a6b8b824c70130a6a8dbbddac7c73e94 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Mon, 23 Apr 2018 11:57:08 +0200
Subject: [PATCH] cosmetics

---
 demos/krylov_sampling.py          | 12 +++++++++---
 nifty4/library/krylov_sampling.py | 20 +++++++++-----------
 2 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/demos/krylov_sampling.py b/demos/krylov_sampling.py
index 65ca19dfe..b5cb5b62e 100644
--- a/demos/krylov_sampling.py
+++ b/demos/krylov_sampling.py
@@ -34,6 +34,7 @@ D_inv = R_p.adjoint * N.inverse * R_p + S.inverse
 
 history = []
 
+
 def sample(D_inv, S, j,  N_samps, N_iter):
     global history
     space = D_inv.domain
@@ -87,8 +88,12 @@ plt.close()
 pltdict = {'alpha': .3, 'linewidth': .2}
 for i in range(N_samps):
     if i == 0:
-        plt.plot(sky(samps_old[i]).val, color='b', **pltdict, label='Traditional samples (residuals)')
-        plt.plot(sky(samps[i]).val, color='r', **pltdict, label='Krylov samples (residuals)')
+        plt.plot(sky(samps_old[i]).val, color='b',
+                 label='Traditional samples (residuals)',
+                 **pltdict)
+        plt.plot(sky(samps[i]).val, color='r',
+                 label='Krylov samples (residuals)',
+                 **pltdict)
     else:
         plt.plot(sky(samps_old[i]).val, color='b', **pltdict)
         plt.plot(sky(samps[i]).val, color='r', **pltdict)
@@ -105,7 +110,8 @@ for i in range(N_samps):
 plt.plot(np.sqrt(D_hat_old / N_samps), 'r--', label='Traditional uncertainty')
 plt.plot(-np.sqrt(D_hat_old / N_samps), 'r--')
 plt.fill_between(range(len(D_hat_new)), -np.sqrt(D_hat_new / N_samps), np.sqrt(
-    D_hat_new / N_samps), facecolor='0.5', alpha=0.5, label='Krylov unvertainty')
+    D_hat_new / N_samps), facecolor='0.5', alpha=0.5,
+    label='Krylov uncertainty')
 plt.plot((s_x - m_x).val, color='k', label='signal - mean')
 plt.legend()
 plt.savefig('Krylov_uncertainty.png')
diff --git a/nifty4/library/krylov_sampling.py b/nifty4/library/krylov_sampling.py
index a0a7b3545..ced72e8ec 100644
--- a/nifty4/library/krylov_sampling.py
+++ b/nifty4/library/krylov_sampling.py
@@ -20,9 +20,10 @@ from numpy import sqrt
 from numpy.random import randn
 
 
-def generate_krylov_samples(D_inv, S, j=None,  N_samps=1, N_iter=10, name=None):
+def generate_krylov_samples(D_inv, S, j=None,  N_samps=1, N_iter=10,
+                            name=None):
     """
-    Generates inverse samples from a curvature D
+    Generates inverse samples from a curvature D.
     This algorithm iteratively generates samples from
     a curvature D by applying conjugate gradient steps
     and resampling the curvature in search direction.
@@ -52,25 +53,22 @@ def generate_krylov_samples(D_inv, S, j=None,  N_samps=1, N_iter=10, name=None):
             D_inv(x) = j
         and the second entry are a list of samples from D_inv.inverse
     """
-    if j is None:
-        j = S.draw_sample(from_inverse=True)
+    j = S.draw_sample(from_inverse=True) if j is None else j
     x = S.draw_sample()
     r = j.copy()
     p = r.copy()
     d = p.vdot(D_inv(p))
-    y = []
-    for i in range(N_samps):
-        y += [S.draw_sample()]
-    for k in range(1, 1 + N_iter):
-        gamma = r.vdot(r) / d
+    y = [S.draw_sample() for _ in range(N_samps)]
+    for k in range(1, 1+N_iter):
+        gamma = r.vdot(r)/d
         if gamma == 0.:
             break
-        x += gamma * p
+        x += gamma*p
         for i in range(N_samps):
             y[i] -= p.vdot(D_inv(y[i])) * p / d
             y[i] += randn() / sqrt(d) * p
         r_new = r - gamma * D_inv(p)
-        beta = r_new.vdot(r_new) / (r.vdot(r))
+        beta = r_new.vdot(r_new) / r.vdot(r)
         r = r_new
         p = r + beta * p
         d = p.vdot(D_inv(p))
-- 
GitLab