Commit b3b3a72a authored by Philipp Arras's avatar Philipp Arras
Browse files

Do not pass inverters around anymore where it is not necessary

parent 0ac48f64
Pipeline #31375 failed with stages
in 4 minutes and 14 seconds
......@@ -169,10 +169,9 @@
"def Curvature(R, N, Sh):\n",
" IC = ift.GradientNormController(iteration_limit=50000,\n",
" tol_abs_gradnorm=0.1)\n",
" inverter = ift.ConjugateGradient(controller=IC)\n",
" # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy\n",
" # helper methods.\n",
" return ift.library.WienerFilterCurvature(R,N,Sh,inverter, sampling_inverter=inverter)"
" return ift.library.WienerFilterCurvature(R,N,Sh,iteration_controller=IC)"
]
},
{
......
%% Cell type:markdown id: tags:
# A NIFTy demonstration
%% Cell type:markdown id: tags:
## IFT: Big Picture
IFT starting point:
$$d = Rs+n$$
Typically, $s$ is a continuous field, $d$ a discrete data vector. Particularly, $R$ is not invertible.
IFT aims at **inverting** the above uninvertible problem in the **best possible way** using Bayesian statistics.
## NIFTy
NIFTy (Numerical Information Field Theory) is a Python framework in which IFT problems can be tackled easily.
Main Interfaces:
- **Spaces**: Cartesian, 2-Spheres (Healpix, Gauss-Legendre) and their respective harmonic spaces.
- **Fields**: Defined on spaces.
- **Operators**: Acting on fields.
%% Cell type:markdown id: tags:
## Wiener Filter: Formulae
### Assumptions
- $d=Rs+n$, $R$ linear operator.
- $\mathcal P (s) = \mathcal G (s,S)$, $\mathcal P (n) = \mathcal G (n,N)$ where $S, N$ are positive definite matrices.
### Posterior
The Posterior is given by:
$$\mathcal P (s|d) \propto P(s,d) = \mathcal G(d-Rs,N) \,\mathcal G(s,S) \propto \mathcal G (m,D) $$
where
$$\begin{align}
m &= Dj \\
D^{-1}&= (S^{-1} +R^\dagger N^{-1} R )\\
j &= R^\dagger N^{-1} d
\end{align}$$
Let us implement this in NIFTy!
%% Cell type:markdown id: tags:
## Wiener Filter: Example
- We assume statistical homogeneity and isotropy. Therefore the signal covariance $S$ is diagonal in harmonic space, and is described by a one-dimensional power spectrum, assumed here as $$P(k) = P_0\,\left(1+\left(\frac{k}{k_0}\right)^2\right)^{-\gamma /2},$$
with $P_0 = 0.2, k_0 = 5, \gamma = 4$.
- $N = 0.2 \cdot \mathbb{1}$.
- Number of data points $N_{pix} = 512$.
- reconstruction in harmonic space.
- Response operator:
$$R = FFT_{\text{harmonic} \rightarrow \text{position}}$$
%% Cell type:code id: tags:
``` python
N_pixels = 512 # Number of pixels
def pow_spec(k):
P0, k0, gamma = [.2, 5, 4]
return P0 / ((1. + (k/k0)**2)**(gamma / 2))
```
%% Cell type:markdown id: tags:
## Wiener Filter: Implementation
%% Cell type:markdown id: tags:
### Import Modules
%% Cell type:code id: tags:
``` python
import numpy as np
np.random.seed(40)
import nifty5 as ift
import matplotlib.pyplot as plt
%matplotlib inline
```
%% Cell type:markdown id: tags:
### Implement Propagator
%% Cell type:code id: tags:
``` python
def Curvature(R, N, Sh):
IC = ift.GradientNormController(iteration_limit=50000,
tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=IC)
# WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy
# helper methods.
return ift.library.WienerFilterCurvature(R,N,Sh,inverter, sampling_inverter=inverter)
return ift.library.WienerFilterCurvature(R,N,Sh,iteration_controller=IC)
```
%% Cell type:markdown id: tags:
### Conjugate Gradient Preconditioning
- $D$ is defined via:
$$D^{-1} = \mathcal S_h^{-1} + R^\dagger N^{-1} R.$$
In the end, we want to apply $D$ to $j$, i.e. we need the inverse action of $D^{-1}$. This is done numerically (algorithm: *Conjugate Gradient*).
<!--
- One can define the *condition number* of a non-singular and normal matrix $A$:
$$\kappa (A) := \frac{|\lambda_{\text{max}}|}{|\lambda_{\text{min}}|},$$
where $\lambda_{\text{max}}$ and $\lambda_{\text{min}}$ are the largest and smallest eigenvalue of $A$, respectively.
- The larger $\kappa$ the slower Conjugate Gradient.
- By default, conjugate gradient solves: $D^{-1} m = j$ for $m$, where $D^{-1}$ can be badly conditioned. If one knows a non-singular matrix $T$ for which $TD^{-1}$ is better conditioned, one can solve the equivalent problem:
$$\tilde A m = \tilde j,$$
where $\tilde A = T D^{-1}$ and $\tilde j = Tj$.
- In our case $S^{-1}$ is responsible for the bad conditioning of $D$ depending on the chosen power spectrum. Thus, we choose
$$T = \mathcal F^\dagger S_h^{-1} \mathcal F.$$
-->
%% Cell type:markdown id: tags:
### Generate Mock data
- Generate a field $s$ and $n$ with given covariances.
- Calculate $d$.
%% Cell type:code id: tags:
``` python
s_space = ift.RGSpace(N_pixels)
h_space = s_space.get_default_codomain()
HT = ift.HarmonicTransformOperator(h_space, target=s_space)
# Operators
Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)
R = HT #*ift.create_harmonic_smoothing_operator((h_space,), 0, 0.02)
# Fields and data
sh = Sh.draw_sample()
noiseless_data=R(sh)
noise_amplitude = np.sqrt(0.2)
N = ift.ScalingOperator(noise_amplitude**2, s_space)
n = ift.Field.from_random(domain=s_space, random_type='normal',
std=noise_amplitude, mean=0)
d = noiseless_data + n
j = R.adjoint_times(N.inverse_times(d))
curv = Curvature(R=R, N=N, Sh=Sh)
D = curv.inverse
```
%% Cell type:markdown id: tags:
### Run Wiener Filter
%% Cell type:code id: tags:
``` python
m = D(j)
```
%% Cell type:markdown id: tags:
### Signal Reconstruction
%% Cell type:code id: tags:
``` python
# Get signal data and reconstruction data
s_data = HT(sh).to_global_data()
m_data = HT(m).to_global_data()
d_data = d.to_global_data()
plt.figure(figsize=(15,10))
plt.plot(s_data, 'r', label="Signal", linewidth=3)
plt.plot(d_data, 'k.', label="Data")
plt.plot(m_data, 'k', label="Reconstruction",linewidth=3)
plt.title("Reconstruction")
plt.legend()
plt.show()
```
%% Cell type:code id: tags:
``` python
plt.figure(figsize=(15,10))
plt.plot(s_data - s_data, 'r', label="Signal", linewidth=3)
plt.plot(d_data - s_data, 'k.', label="Data")
plt.plot(m_data - s_data, 'k', label="Reconstruction",linewidth=3)
plt.axhspan(-noise_amplitude,noise_amplitude, facecolor='0.9', alpha=.5)
plt.title("Residuals")
plt.legend()
plt.show()
```
%% Cell type:markdown id: tags:
### Power Spectrum
%% Cell type:code id: tags:
``` python
s_power_data = ift.power_analyze(sh).to_global_data()
m_power_data = ift.power_analyze(m).to_global_data()
plt.figure(figsize=(15,10))
plt.loglog()
plt.xlim(1, int(N_pixels/2))
ymin = min(m_power_data)
plt.ylim(ymin, 1)
xs = np.arange(1,int(N_pixels/2),.1)
plt.plot(xs, pow_spec(xs), label="True Power Spectrum", color='k',alpha=0.5)
plt.plot(s_power_data, 'r', label="Signal")
plt.plot(m_power_data, 'k', label="Reconstruction")
plt.axhline(noise_amplitude**2 / N_pixels, color="k", linestyle='--', label="Noise level", alpha=.5)
plt.axhspan(noise_amplitude**2 / N_pixels, ymin, facecolor='0.9', alpha=.5)
plt.title("Power Spectrum")
plt.legend()
plt.show()
```
%% Cell type:markdown id: tags:
## Wiener Filter on Incomplete Data
%% Cell type:code id: tags:
``` python
# Operators
Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)
N = ift.ScalingOperator(noise_amplitude**2,s_space)
# R is defined below
# Fields
sh = Sh.draw_sample()
s = HT(sh)
n = ift.Field.from_random(domain=s_space, random_type='normal',
std=noise_amplitude, mean=0)
```
%% Cell type:markdown id: tags:
### Partially Lose Data
%% Cell type:code id: tags:
``` python
l = int(N_pixels * 0.2)
h = int(N_pixels * 0.2 * 2)
mask = np.full(s_space.shape, 1.)
mask[l:h] = 0
mask = ift.Field.from_global_data(s_space, mask)
R = ift.DiagonalOperator(mask)*HT
n = n.to_global_data()
n[l:h] = 0
n = ift.Field.from_global_data(s_space, n)
d = R(sh) + n
```
%% Cell type:code id: tags:
``` python
curv = Curvature(R=R, N=N, Sh=Sh)
D = curv.inverse
j = R.adjoint_times(N.inverse_times(d))
m = D(j)
```
%% Cell type:markdown id: tags:
### Compute Uncertainty
%% Cell type:code id: tags:
``` python
m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 200)
```
%% Cell type:markdown id: tags:
### Get data
%% Cell type:code id: tags:
``` python
# Get signal data and reconstruction data
s_data = s.to_global_data()
m_data = HT(m).to_global_data()
m_var_data = m_var.to_global_data()
uncertainty = np.sqrt(m_var_data)
d_data = d.to_global_data()
# Set lost data to NaN for proper plotting
d_data[d_data == 0] = np.nan
```
%% Cell type:code id: tags:
``` python
fig = plt.figure(figsize=(15,10))
plt.axvspan(l, h, facecolor='0.8',alpha=0.5)
plt.fill_between(range(N_pixels), m_data - uncertainty, m_data + uncertainty, facecolor='0.5', alpha=0.5)
plt.plot(s_data, 'r', label="Signal", alpha=1, linewidth=3)
plt.plot(d_data, 'k.', label="Data")
plt.plot(m_data, 'k', label="Reconstruction", linewidth=3)
plt.title("Reconstruction of incomplete data")
plt.legend()
```
%% Cell type:markdown id: tags:
# 2d Example
%% Cell type:code id: tags:
``` python
N_pixels = 256 # Number of pixels
sigma2 = 2. # Noise variance
def pow_spec(k):
P0, k0, gamma = [.2, 2, 4]
return P0 * (1. + (k/k0)**2)**(-gamma/2)
s_space = ift.RGSpace([N_pixels, N_pixels])
```
%% Cell type:code id: tags:
``` python
h_space = s_space.get_default_codomain()
HT = ift.HarmonicTransformOperator(h_space,s_space)
# Operators
Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)
N = ift.ScalingOperator(sigma2,s_space)
# Fields and data
sh = Sh.draw_sample()
n = ift.Field.from_random(domain=s_space, random_type='normal',
std=np.sqrt(sigma2), mean=0)
# Lose some data
l = int(N_pixels * 0.33)
h = int(N_pixels * 0.33 * 2)
mask = np.full(s_space.shape, 1.)
mask[l:h,l:h] = 0.
mask = ift.Field.from_global_data(s_space, mask)
R = ift.DiagonalOperator(mask)*HT
n = n.to_global_data()
n[l:h, l:h] = 0
n = ift.Field.from_global_data(s_space, n)
curv = Curvature(R=R, N=N, Sh=Sh)
D = curv.inverse
d = R(sh) + n
j = R.adjoint_times(N.inverse_times(d))
# Run Wiener filter
m = D(j)
# Uncertainty
m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 20)
# Get data
s_data = HT(sh).to_global_data()
m_data = HT(m).to_global_data()
m_var_data = m_var.to_global_data()
d_data = d.to_global_data()
uncertainty = np.sqrt(np.abs(m_var_data))
```
%% Cell type:code id: tags:
``` python
cm = ['magma', 'inferno', 'plasma', 'viridis'][1]
mi = np.min(s_data)
ma = np.max(s_data)
fig, axes = plt.subplots(1, 2, figsize=(15, 7))
data = [s_data, d_data]
caption = ["Signal", "Data"]
for ax in axes.flat:
im = ax.imshow(data.pop(0), interpolation='nearest', cmap=cm, vmin=mi,
vmax=ma)
ax.set_title(caption.pop(0))
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)
```
%% Cell type:code id: tags:
``` python
mi = np.min(s_data)
ma = np.max(s_data)
fig, axes = plt.subplots(3, 2, figsize=(15, 22.5))
sample = HT(curv.draw_sample(from_inverse=True)+m).to_global_data()
post_mean = (m_mean + HT(m)).to_global_data()
data = [s_data, m_data, post_mean, sample, s_data - m_data, uncertainty]
caption = ["Signal", "Reconstruction", "Posterior mean", "Sample", "Residuals", "Uncertainty Map"]
for ax in axes.flat:
im = ax.imshow(data.pop(0), interpolation='nearest', cmap=cm, vmin=mi, vmax=ma)
ax.set_title(caption.pop(0))
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)
```
%% Cell type:markdown id: tags:
### Is the uncertainty map reliable?
%% Cell type:code id: tags:
``` python
precise = (np.abs(s_data-m_data) < uncertainty)
print("Error within uncertainty map bounds: " + str(np.sum(precise) * 100 / N_pixels**2) + "%")
plt.figure(figsize=(15,10))
plt.imshow(precise.astype(float), cmap="brg")
plt.colorbar()
```
%% Cell type:markdown id: tags:
# Start Coding
## NIFTy Repository + Installation guide
https://gitlab.mpcdf.mpg.de/ift/NIFTy
NIFTy v5 **more or less stable!**
......
......@@ -85,15 +85,13 @@ if __name__ == "__main__":
LS = ift.LineSearchStrongWolfe(c2=0.02)
minimizer = ift.RelaxedNewton(IC1, line_searcher=LS)
ICI = ift.GradientNormController(iteration_limit=500,
tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=ICI)
IC = ift.GradientNormController(iteration_limit=500,
tol_abs_gradnorm=1e-3)
for i in range(20):
power0 = Distributor(ift.exp(0.5*t0))
map0_energy = ift.library.NonlinearWienerFilterEnergy(
m0, d, MeasurementOperator, nonlinearity, HT, power0, N, S,
inverter=inverter)
m0, d, MeasurementOperator, nonlinearity, HT, power0, N, S, IC)
# Minimization with chosen minimizer
map0_energy, convergence = minimizer(map0_energy)
......@@ -106,7 +104,8 @@ if __name__ == "__main__":
power0_energy = ift.library.NonlinearPowerEnergy(
position=t0, d=d, N=N, xi=m0, D=D0, ht=HT,
Instrument=MeasurementOperator, nonlinearity=nonlinearity,
Distributor=Distributor, sigma=1., samples=2, inverter=inverter)
Distributor=Distributor, sigma=1., samples=2,
iteration_controller=IC)
power0_energy = minimizer(power0_energy)[0]
......
......@@ -78,15 +78,13 @@ if __name__ == "__main__":
LS = ift.LineSearchStrongWolfe(c2=0.02)
minimizer = ift.RelaxedNewton(IC1, line_searcher=LS)
ICI = ift.GradientNormController(iteration_limit=500,
tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=ICI)
IC = ift.GradientNormController(iteration_limit=500,
tol_abs_gradnorm=1e-3)
for i in range(20):
power0 = Distributor(ift.exp(0.5*t0))
map0_energy = ift.library.NonlinearWienerFilterEnergy(
m0, d, MeasurementOperator, nonlinearity, HT, power0, N, S,
inverter=inverter)
m0, d, MeasurementOperator, nonlinearity, HT, power0, N, S, IC)
# Minimization with chosen minimizer
map0_energy, convergence = minimizer(map0_energy)
......@@ -99,7 +97,7 @@ if __name__ == "__main__":
power0_energy = ift.library.NonlinearPowerEnergy(
position=t0, d=d, N=N, xi=m0, D=D0, ht=HT,
Instrument=MeasurementOperator, nonlinearity=nonlinearity,
Distributor=Distributor, sigma=1., samples=2, inverter=inverter)
Distributor=Distributor, sigma=1., samples=2, iteration_controller=IC)
power0_energy = minimizer(power0_energy)[0]
......
......@@ -52,14 +52,13 @@ if __name__ == "__main__":
LS = ift.LineSearchStrongWolfe(c2=0.02)
minimizer = ift.RelaxedNewton(IC1, line_searcher=LS)
ICI = ift.GradientNormController(iteration_limit=2000,
tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=ICI)
IC = ift.GradientNormController(iteration_limit=2000,
tol_abs_gradnorm=1e-3)
# initial guess
m = ift.full(h_space, 1e-7)
map_energy = ift.library.NonlinearWienerFilterEnergy(
m, d, R, nonlinearity, HT, power, N, S, inverter=inverter)
m, d, R, nonlinearity, HT, power, N, S, IC)
# Minimization with chosen minimizer
map_energy, convergence = minimizer(map_energy)
......
......@@ -76,10 +76,9 @@ if __name__ == "__main__":
ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=0.1)
sampling_ctrl = ift.GradientNormController(name="sampling",
tol_abs_gradnorm=1e2)
inverter = ift.ConjugateGradient(controller=ctrl)
sampling_inverter = ift.ConjugateGradient(controller=sampling_ctrl)
wiener_curvature = ift.library.WienerFilterCurvature(
S=S, N=N, R=R, inverter=inverter, sampling_inverter=sampling_inverter)
S=S, N=N, R=R, iteration_controller=ctrl,
iteration_controller_sampling=sampling_ctrl)
m_k = wiener_curvature.inverse_times(j)
m = ht(m_k)
......
......@@ -50,10 +50,9 @@ if __name__ == "__main__":
ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=1e-2)
sampling_ctrl = ift.GradientNormController(name="sampling",
tol_abs_gradnorm=2e1)
inverter = ift.ConjugateGradient(controller=ctrl)
sampling_inverter = ift.ConjugateGradient(controller=sampling_ctrl)
wiener_curvature = ift.library.WienerFilterCurvature(
S=S, N=N, R=R, inverter=inverter, sampling_inverter=sampling_inverter)
S=S, N=N, R=R, iteration_controller=ctrl,
iteration_controller_sampling=sampling_ctrl)
m_k = wiener_curvature.inverse_times(j)
m = ht(m_k)
......
......@@ -81,9 +81,8 @@ if __name__ == "__main__":
j = R.adjoint_times(N.inverse_times(d))
IC = ift.GradientNormController(name="inverter", iteration_limit=500,
tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=IC)
D = (ift.SandwichOperator.make(R, N.inverse) + Phi_h.inverse).inverse
D = ift.InversionEnabler(D, inverter, approximation=Phi_h)
D = ift.InversionEnabler(D, IC, approximation=Phi_h)
m = HT(D(j))
# Uncertainty
......@@ -116,8 +115,7 @@ if __name__ == "__main__":
# initial guess
psi0 = ift.full(h_domain, 1e-7)
energy = ift.library.PoissonEnergy(psi0, data, R0, nonlin, HT, Phi_h,
inverter)
energy = ift.library.PoissonEnergy(psi0, data, R0, nonlin, HT, Phi_h, IC)
IC1 = ift.GradientNormController(name="IC1", iteration_limit=200,
tol_abs_gradnorm=1e-4)
minimizer = ift.RelaxedNewton(IC1)
......
......@@ -39,17 +39,17 @@ N_iter = 100
tol = 1e-3
IC = ift.GradientNormController(tol_abs_gradnorm=tol, iteration_limit=N_iter)
inverter = ift.ConjugateGradient(IC)
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter,
sampling_inverter=inverter)
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p,
iteration_controller=IC,
iteration_controller_sampling=IC)
m_xi = curv.inverse_times(j)
samps_long = [curv.draw_sample(from_inverse=True) for i in range(N_samps)]
tol = 1e2
IC = ift.GradientNormController(tol_abs_gradnorm=tol, iteration_limit=N_iter)
inverter = ift.ConjugateGradient(IC)
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p, inverter=inverter,
sampling_inverter=inverter)
curv = ift.library.WienerFilterCurvature(S=S, N=N, R=R_p,
iteration_controller=IC,
iteration_controller_sampling=IC)
samps_short = [curv.draw_sample(from_inverse=True) for i in range(N_samps)]
# Compute mean
......
......@@ -38,8 +38,8 @@ if __name__ == "__main__":
j = Rh.adjoint_times(N.inverse_times(d))
ctrl = ift.GradientNormController(name="Iter", tol_abs_gradnorm=1e-10,
iteration_limit=300)
inverter = ift.ConjugateGradient(controller=ctrl)
Di = ift.library.WienerFilterCurvature(S=S, R=Rh, N=N, inverter=inverter)
Di = ift.library.WienerFilterCurvature(S=S, R=Rh, N=N,
iteration_controller=ctrl)
mh = Di.inverse_times(j)
m = ht(mh)
......
......@@ -98,10 +98,9 @@ if __name__ == "__main__":
IC = ift.GradientNormController(name="inverter", iteration_limit=1000,
tol_abs_gradnorm=0.0001)
inverter = ift.ConjugateGradient(controller=IC)
# setting up measurement precision matrix M
M = (ift.SandwichOperator.make(R.adjoint, Sh) + N)
M = ift.InversionEnabler(M, inverter)
M = ift.InversionEnabler(M, IC)
m = Sh(R.adjoint(M.inverse_times(d)))
# Plotting
......
......@@ -52,9 +52,8 @@ if __name__ == "__main__":
j = R.adjoint_times(N.inverse_times(d))
IC = ift.GradientNormController(name="inverter", iteration_limit=500,
tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=IC)
D = (ift.SandwichOperator.make(R, N.inverse) + Sh.inverse).inverse
D = ift.InversionEnabler(D, inverter, approximation=Sh)
D = ift.InversionEnabler(D, IC, approximation=Sh)
m = D(j)
# Plotting
......
......@@ -78,9 +78,8 @@ if __name__ == "__main__":
j = R.adjoint_times(N.inverse_times(data))
ctrl = ift.GradientNormController(
name="inverter", tol_abs_gradnorm=1e-5/(nu.K*(nu.m**dimensionality)))
inverter = ift.ConjugateGradient(controller=ctrl)
wiener_curvature = ift.library.WienerFilterCurvature(
S=S, N=N, R=R, inverter=inverter)
wiener_curvature = ift.library.WienerFilterCurvature(S=S, N=N, R=R,
iteration_controller=ctrl)
m = wiener_curvature.inverse_times(j)
m_s = HT(m)
......
......@@ -47,15 +47,14 @@ if __name__ == "__main__":
# Choose minimization strategy
ctrl = ift.GradientNormController(name="inverter", tol_abs_gradnorm=0.1)
inverter = ift.ConjugateGradient(controller=ctrl)
controller = ift.GradientNormController(name="min", tol_abs_gradnorm=0.1)
minimizer = ift.RelaxedNewton(controller=controller)
m0 = ift.full(h_space, 0.)
# Initialize Wiener filter energy
energy = ift.library.WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S,
inverter=inverter,
sampling_inverter=inverter)
iteration_controller=ctrl,
iteration_controller_sampling=ctrl)
energy, convergence = minimizer(energy)
m = energy.position
......
......@@ -63,7 +63,7 @@ class NonlinearPowerEnergy(Energy):
# MR FIXME: docstring incomplete and outdated
def __init__(self, position, d, N, xi, D, ht, Instrument, nonlinearity,
Distributor, sigma=0., samples=3, xi_sample_list=None,
inverter=None):
iteration_controller=None):
super(NonlinearPowerEnergy, self).__init__(position)
self.xi = xi
self.D = D
......@@ -83,7 +83,7 @@ class NonlinearPowerEnergy(Energy):
xi_sample_list = [D.draw_sample(from_inverse=True) + xi
for _ in range(samples)]
self.xi_sample_list = xi_sample_list
self.inverter = inverter
self._ic = iteration_controller
A = Distributor(exp(.5 * position))
......@@ -118,7 +118,7 @@ class NonlinearPowerEnergy(Energy):
self.Distributor, sigma=self.sigma,
samples=len(self.xi_sample_list),
xi_sample_list=self.xi_sample_list,
inverter=self.inverter)
iteration_controller=self._ic)
@property
def value(self):
......@@ -139,4 +139,4 @@ class NonlinearPowerEnergy(Energy):
op = LinearizedResponse.adjoint*self.N.inverse*LinearizedResponse
result = op if result is None else result + op
result = result*(1./len(self.xi_sample_list)) + self.T
return InversionEnabler(result, self.inverter)
return InversionEnabler(result, self._ic)
......@@ -24,8 +24,7 @@ from ..sugar import makeOp
class NonlinearWienerFilterEnergy(Energy):
def __init__(self, position, d, Instrument, nonlinearity, ht, power, N, S,
inverter=None,
sampling_inverter=None):
iteration_controller=None, iteration_controller_sampling=None):
super(NonlinearWienerFilterEnergy, self).__init__(position=position)
self.d = d.lock()
self.Instrument = Instrument
......@@ -37,10 +36,10 @@ class NonlinearWienerFilterEnergy(Energy):
residual = d - Instrument(nonlinearity(m))
self.N = N
self.S = S
self.inverter = inverter
if sampling_inverter is None:
sampling_inverter = inverter
self.sampling_inverter = sampling_inverter
self._ic = iteration_controller
if iteration_controller_sampling is None:
iteration_controller_sampling = self._ic
self._ic_samp = iteration_controller_sampling
t1 = S.inverse_times(position)
t2 = N.inverse_times(residual)
self._value = 0.5 * (position.vdot(t1) + residual.vdot(t2)).real
......@@ -51,7 +50,7 @@ class NonlinearWienerFilterEnergy(Energy):
def at(self, position):
return self.__class__(position, self.d, self.Instrument,
self.nonlinearity, self.ht, self.power, self.N,
self.S, self.inverter)
self.S, self._ic, self._ic_samp)
@property
def value(self):
......@@ -64,5 +63,5 @@ class NonlinearWienerFilterEnergy(Energy):
@property
@memo
def curvature(self):
return WienerFilterCurvature(self.R, self.N, self.S, self.inverter,
self.sampling_inverter)
return WienerFilterCurvature(self.R, self.N, self.S, self._ic,
self._ic_samp)
......@@ -25,9 +25,9 @@ from ..sugar import log
class PoissonEnergy(Energy):
def __init__(self, position, d, Instrument, nonlinearity, ht, Phi_h,
inverter=None):
iteration_controller=None):
super(PoissonEnergy, self).__init__(position=position)
self._inverter = inverter
self._ic = iteration_controller
self._d = d
self._Instrument = Instrument
self._nonlinearity = nonlinearity
......@@ -51,7 +51,7 @@ class PoissonEnergy(Energy):
def at(self, position):
return self.__class__(position, self._d, self._Instrument,
self._nonlinearity, self._ht, self._Phi_h,
self._inverter)
self._ic)
@property
def value(self):
......@@ -63,5 +63,4 @@ class PoissonEnergy(Energy):
@property
def curvature(self):
return InversionEnabler(self._curv, self._inverter,
approximation=self._Phi_h.inverse)
return InversionEnabler(self._curv, self._ic, self._Phi_h.inverse)
......@@ -21,7 +21,8 @@ from ..operators.inversion_enabler import InversionEnabler
from ..operators.sampling_enabler import SamplingEnabler
def WienerFilterCurvature(R, N, S, inverter, sampling_inverter=None):
def WienerFilterCurvature(R, N, S, iteration_controller=None,
iteration_controller_sampling=None):
"""The curvature of the WienerFilterEnergy.
This operator implements the second derivative of the
......@@ -37,16 +38,16 @@ def WienerFilterCurvature(R, N, S, inverter, sampling_inverter=None):
The noise covariance.
S : DiagonalOperator
The prior signal covariance
inverter : Minimizer
The minimizer to use during numerical inversion
sampling_inverter : Minimizer
The minimizer to use during numerical sampling
if None, it is not possible to draw inverse samples
default: None
iteration_controller : IterationController
The iteration controller to use during numerical inversion via
ConjugateGradient.
iteration_controller_sampling : IterationController
The iteration controller to use for sampling.
"""
M = SandwichOperator.make(R, N.inverse)
if sampling_inverter is not None:
op = SamplingEnabler(M, S.inverse, sampling_inverter, S.inverse)
if iteration_controller is not None:
op = SamplingEnabler(M, S.inverse, iteration_controller_sampling,
S.inverse)
else:
op = M + S.inverse
return InversionEnabler(op, inverter, S.inverse)
return InversionEnabler(op, iteration_controller, S.inverse)
......@@ -20,8 +20,8 @@ from ..minimization.quadratic_energy import QuadraticEnergy
from .wiener_filter_curvature import WienerFilterCurvature
def WienerFilterEnergy(position, d, R, N, S, inverter=None,
sampling_inverter=None):
def WienerFilterEnergy(position, d, R, N, S, iteration_controller=None,
iteration_controller_sampling=None):
"""The Energy for the Wiener filter.
It covers the case of linear measurement with
......@@ -48,6 +48,7 @@ def WienerFilterEnergy(position, d, R, N, S, inverter=None,
if None, it is not possible to draw inverse samples
default: None
"""
op = WienerFilterCurvature(R, N, S, inverter, sampling_inverter)
op = WienerFilterCurvature(R, N, S, iteration_controller,
iteration_controller_sampling)
vec = R.adjoint_times(N.inverse_times(d))
return QuadraticEnergy(position, op, vec)
......@@ -61,6 +61,4 @@ class EnergySum(Energy):
if precon is None and self._precon_idx is not None:
precon = self._energies[self._precon_idx].curvature
from ..operators.inversion_enabler import InversionEnabler
from .conjugate_gradient import ConjugateGradient
return InversionEnabler(
res, ConjugateGradient(self._min_controller), precon)
return InversionEnabler(res, self._min_controller, precon)
......@@ -16,11 +16,13 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..minimization.quadratic_energy import QuadraticEnergy
from ..minimization.iteration_controller import IterationController
import numpy as np
from ..logger import logger
from ..minimization.conjugate_gradient import ConjugateGradient
from ..minimization.iteration_controller import IterationController
from ..minimization.quadratic_energy import QuadraticEnergy
from .endomorphic_operator import EndomorphicOperator
import numpy as np
class InversionEnabler(EndomorphicOperator):
......@@ -34,9 +36,9 @@ class InversionEnabler(EndomorphicOperator):
The InversionEnabler object will support the same operation modes as
`op`, and additionally the inverse set. The newly-added modes will
be computed by iterative inversion.
inverter : :class:`Minimizer`
The minimizer to use for the iterative numerical inversion.
Typically, this is a :class:`ConjugateGradient` object.
iteration_controller : :class:`IterationController`
The iteration controller to use for the iterative numerical inversion
done by a :class:`ConjugateGradient` object.
approximation : :class:`LinearOperator`, optional
if not None, this operator should be an approximation to `op`, which
supports the operation modes that `op` doesn't have. It is used as a
......@@ -44,10 +46,10 @@ class InversionEnabler(EndomorphicOperator):
convergence.
"""
def __init__(self, op, inverter, approximation=None):
def __init__(self, op, iteration_controller, approximation=None):
super(InversionEnabler, self).__init__()
self._op = op
self._inverter = inverter
self._ic = iteration_controller
self._approximation = approximation
@property
......@@ -70,7 +72,8 @@ class InversionEnabler(EndomorphicOperator):
if prec is not None:
prec = prec._flip_modes(self._ilog[mode])
energy = QuadraticEnergy(x0, invop, x)
r, stat = self._inverter(energy, preconditioner=prec)
inverter = ConjugateGradient(self._ic)
r, stat = inverter(energy, preconditioner=prec)
if stat != IterationController.CONVERGED:
logger.warning("Error detected during operator inversion")
return r.position
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment