...
 
Commits (1)
......@@ -37,10 +37,11 @@ D_inv = ift.SandwichOperator.make(R_p, N.inverse) + S.inverse
N_samps = 200
N_iter = 100
IC = ift.GradientNormController(tol_abs_gradnorm=1e-3, iteration_limit=N_iter)
m, samps = ift.library.generate_krylov_samples(D_inv, S, j, N_samps, IC)
m_x = sky(m)
samps = ift.library.generate_krylov_samples(D_inv, S, N_samps, IC)
inverter = ift.ConjugateGradient(IC)
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter)
m = curv.inverse_times(j)
m_x = sky(m)
samps_old = [curv.draw_sample(from_inverse=True) for i in range(N_samps)]
plt.plot(d.to_global_data(), '+', label="data", alpha=.5)
......
......@@ -20,26 +20,23 @@ import numpy as np
from ..minimization.quadratic_energy import QuadraticEnergy
def generate_krylov_samples(D_inv, S, j, N_samps, controller):
def generate_krylov_samples(D_inv, S, N_samps, controller):
"""
Generates inverse samples from a curvature D.
This algorithm iteratively generates samples from
a curvature D by applying conjugate gradient steps
and resampling the curvature in search direction.
It is basically just a more stable version of
Wiener Filter samples
Parameters
----------
D_inv : EndomorphicOperator
The curvature which will be the inverse of the covarianc
D_inv : WienerFilterCurvature
The curvature which will be the inverse of the covariance
of the generated samples
S : EndomorphicOperator (from which one can sample)
A prior covariance operator which is used to generate prior
samples that are then iteratively updated
j : Field, optional
A Field to which the inverse of D_inv is applied. The solution
of this matrix inversion problem is a side product of generating
the samples.
If not supplied, it is sampled from the inverse prior.
N_samps : Int
How many samples to generate.
controller : IterationController
......@@ -47,56 +44,62 @@ def generate_krylov_samples(D_inv, S, j, N_samps, controller):
Returns
-------
(solution, samples) : A tuple of a field 'solution' and a list of fields
'samples'. The first entry of the tuple is the solution x to
D_inv(x) = j
and the second entry are a list of samples from D_inv.inverse
samples : a list of samples from D_inv.inverse
"""
# RL FIXME: make consistent with complex numbers
j = S.draw_sample(from_inverse=True) if j is None else j
energy = QuadraticEnergy(j.empty_copy().fill(0.), D_inv, j)
y = [S.draw_sample() for _ in range(N_samps)]
status = controller.start(energy)
if status != controller.CONTINUE:
return energy.position, y
r = energy.gradient
d = r.copy()
previous_gamma = r.vdot(r).real
if previous_gamma == 0:
return energy.position, y
while True:
q = energy.curvature(d)
ddotq = d.vdot(q).real
if ddotq == 0.:
logger.error("Error: ConjugateGradient: ddotq==0.")
return energy.position, y
alpha = previous_gamma/ddotq
if alpha < 0:
logger.error("Error: ConjugateGradient: alpha<0.")
return energy.position, y
for i in range(len(y)):
y[i] += (np.random.randn()*np.sqrt(ddotq) - y[i].vdot(q))/ddotq * d
q *= -alpha
r = r + q
energy = energy.at_with_grad(energy.position - alpha*d, r)
gamma = r.vdot(r).real
if gamma == 0:
return energy.position, y
status = controller.check(energy)
samples = []
for i in range(N_samps):
x0 = S.draw_sample()
y = x0*0
j = y*0
#j = y
energy = QuadraticEnergy(x0, D_inv, j)
status = controller.start(energy)
if status != controller.CONTINUE:
return energy.position, y
d *= max(0, gamma/previous_gamma)
d += r
previous_gamma = gamma
samples += [y]
break
r = energy.gradient
d = r.copy()
previous_gamma = r.vdot(r).real
if previous_gamma == 0:
samples += [y+energy.position]
break
while True:
q = energy.curvature(d)
ddotq = d.vdot(q).real
if ddotq == 0.:
logger.error("Error: ConjugateGradient: ddotq==0.")
samples += [y+energy.position]
break
alpha = previous_gamma/ddotq
if alpha < 0:
logger.error("Error: ConjugateGradient: alpha<0.")
samples += [y+energy.position]
break
y += (np.random.randn()*np.sqrt(ddotq) )/ddotq * d
q *= -alpha
r = r + q
energy = energy.at_with_grad(energy.position - alpha*d, r)
gamma = r.vdot(r).real
if gamma == 0:
samples += [y+energy.position]
break
status = controller.check(energy)
if status != controller.CONTINUE:
samples += [y+energy.position]
break
d *= max(0, gamma/previous_gamma)
d += r
previous_gamma = gamma
return samples