From 4a35010ca362c855efb992b334bd129d25d31bb7 Mon Sep 17 00:00:00 2001
From: Reimar Leike <reimar@leike.name>
Date: Mon, 4 Jun 2018 15:32:45 +0200
Subject: [PATCH] introduced a new sampling routine that does not have the
systematic errors of the old krylov samples but is slower
---
demos/krylov_sampling.py | 5 +-
nifty4/library/krylov_sampling.py | 121 +++++++++++++++---------------
2 files changed, 65 insertions(+), 61 deletions(-)
diff --git a/demos/krylov_sampling.py b/demos/krylov_sampling.py
index 4987437ae..afb0822f2 100644
--- a/demos/krylov_sampling.py
+++ b/demos/krylov_sampling.py
@@ -37,10 +37,11 @@ D_inv = ift.SandwichOperator.make(R_p, N.inverse) + S.inverse
N_samps = 200
N_iter = 100
IC = ift.GradientNormController(tol_abs_gradnorm=1e-3, iteration_limit=N_iter)
-m, samps = ift.library.generate_krylov_samples(D_inv, S, j, N_samps, IC)
-m_x = sky(m)
+samps = ift.library.generate_krylov_samples(D_inv, S, N_samps, IC)
inverter = ift.ConjugateGradient(IC)
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter)
+m = curv.inverse_times(j)
+m_x = sky(m)
samps_old = [curv.draw_sample(from_inverse=True) for i in range(N_samps)]
plt.plot(d.to_global_data(), '+', label="data", alpha=.5)
diff --git a/nifty4/library/krylov_sampling.py b/nifty4/library/krylov_sampling.py
index d83a8a8f2..3196d0c52 100644
--- a/nifty4/library/krylov_sampling.py
+++ b/nifty4/library/krylov_sampling.py
@@ -20,26 +20,23 @@ import numpy as np
from ..minimization.quadratic_energy import QuadraticEnergy
-def generate_krylov_samples(D_inv, S, j, N_samps, controller):
+def generate_krylov_samples(D_inv, S, N_samps, controller):
"""
Generates inverse samples from a curvature D.
This algorithm iteratively generates samples from
a curvature D by applying conjugate gradient steps
and resampling the curvature in search direction.
+ It is basically just a more stable version of
+ Wiener Filter samples
Parameters
----------
- D_inv : EndomorphicOperator
- The curvature which will be the inverse of the covarianc
+ D_inv : WienerFilterCurvature
+ The curvature which will be the inverse of the covariance
of the generated samples
S : EndomorphicOperator (from which one can sample)
A prior covariance operator which is used to generate prior
samples that are then iteratively updated
- j : Field, optional
- A Field to which the inverse of D_inv is applied. The solution
- of this matrix inversion problem is a side product of generating
- the samples.
- If not supplied, it is sampled from the inverse prior.
N_samps : Int
How many samples to generate.
controller : IterationController
@@ -47,56 +44,62 @@ def generate_krylov_samples(D_inv, S, j, N_samps, controller):
Returns
-------
- (solution, samples) : A tuple of a field 'solution' and a list of fields
- 'samples'. The first entry of the tuple is the solution x to
- D_inv(x) = j
- and the second entry are a list of samples from D_inv.inverse
+ samples : a list of samples from D_inv.inverse
"""
- # RL FIXME: make consistent with complex numbers
- j = S.draw_sample(from_inverse=True) if j is None else j
- energy = QuadraticEnergy(j.empty_copy().fill(0.), D_inv, j)
- y = [S.draw_sample() for _ in range(N_samps)]
-
- status = controller.start(energy)
- if status != controller.CONTINUE:
- return energy.position, y
-
- r = energy.gradient
- d = r.copy()
-
- previous_gamma = r.vdot(r).real
- if previous_gamma == 0:
- return energy.position, y
-
- while True:
- q = energy.curvature(d)
- ddotq = d.vdot(q).real
- if ddotq == 0.:
- logger.error("Error: ConjugateGradient: ddotq==0.")
- return energy.position, y
- alpha = previous_gamma/ddotq
-
- if alpha < 0:
- logger.error("Error: ConjugateGradient: alpha<0.")
- return energy.position, y
-
- for i in range(len(y)):
- y[i] += (np.random.randn()*np.sqrt(ddotq) - y[i].vdot(q))/ddotq * d
-
- q *= -alpha
- r = r + q
-
- energy = energy.at_with_grad(energy.position - alpha*d, r)
-
- gamma = r.vdot(r).real
- if gamma == 0:
- return energy.position, y
-
- status = controller.check(energy)
+ samples = []
+ for i in range(N_samps):
+ x0 = S.draw_sample()
+ y = x0*0
+ j = y*0
+ #j = y
+ energy = QuadraticEnergy(x0, D_inv, j)
+
+ status = controller.start(energy)
if status != controller.CONTINUE:
- return energy.position, y
-
- d *= max(0, gamma/previous_gamma)
- d += r
-
- previous_gamma = gamma
+ samples += [y]
+ break
+
+ r = energy.gradient
+ d = r.copy()
+
+ previous_gamma = r.vdot(r).real
+ if previous_gamma == 0:
+ samples += [y+energy.position]
+ break
+
+ while True:
+ q = energy.curvature(d)
+ ddotq = d.vdot(q).real
+ if ddotq == 0.:
+ logger.error("Error: ConjugateGradient: ddotq==0.")
+ samples += [y+energy.position]
+ break
+ alpha = previous_gamma/ddotq
+
+ if alpha < 0:
+ logger.error("Error: ConjugateGradient: alpha<0.")
+ samples += [y+energy.position]
+ break
+
+ y += (np.random.randn()*np.sqrt(ddotq) )/ddotq * d
+
+ q *= -alpha
+ r = r + q
+
+ energy = energy.at_with_grad(energy.position - alpha*d, r)
+
+ gamma = r.vdot(r).real
+ if gamma == 0:
+ samples += [y+energy.position]
+ break
+
+ status = controller.check(energy)
+ if status != controller.CONTINUE:
+ samples += [y+energy.position]
+ break
+
+ d *= max(0, gamma/previous_gamma)
+ d += r
+
+ previous_gamma = gamma
+ return samples
--
GitLab