Commit 22c7de10 authored by Philipp Arras's avatar Philipp Arras
Browse files

WienerFilterCurvature: Interface change

This interface change is necessary since the sampling dtype cannot
easily be set by a single keyword for complicated operators.
parent 44ee2914
Pipeline #107724 passed with stages
in 21 minutes and 26 seconds
Changes since NIFTy 7
=====================
WienerFilterCurvature interface change
--------------------------------------
`ift.WienerFilterCurvature` does not expect sampling dtypes for the likelihood
and the prior anymore. These have to be set with an `ift.SamplingDtypeSetter`
beforehand.
Minisanity
----------
......
%% Cell type:markdown id: tags:
# Code example: Wiener filter
%% Cell type:markdown id: tags:
## Introduction
IFT starting point:
$$d = Rs+n$$
Typically, $s$ is a continuous field, $d$ a discrete data vector. Particularly, $R$ is not invertible.
IFT aims at **inverting** the above uninvertible problem in the **best possible way** using Bayesian statistics.
NIFTy (Numerical Information Field Theory) is a Python framework in which IFT problems can be tackled easily.
Main Interfaces:
- **Spaces**: Cartesian, 2-Spheres (Healpix, Gauss-Legendre) and their respective harmonic spaces.
- **Fields**: Defined on spaces.
- **Operators**: Acting on fields.
%% Cell type:markdown id: tags:
## Wiener filter on one-dimensional fields
### Assumptions
- $d=Rs+n$, $R$ linear operator.
- $\mathcal P (s) = \mathcal G (s,S)$, $\mathcal P (n) = \mathcal G (n,N)$ where $S, N$ are positive definite matrices.
### Posterior
The Posterior is given by:
$$\mathcal P (s|d) \propto P(s,d) = \mathcal G(d-Rs,N) \,\mathcal G(s,S) \propto \mathcal G (s-m,D) $$
where
$$m = Dj$$
with
$$D = (S^{-1} +R^\dagger N^{-1} R)^{-1} , \quad j = R^\dagger N^{-1} d.$$
Let us implement this in NIFTy!
%% Cell type:markdown id: tags:
### In NIFTy
- We assume statistical homogeneity and isotropy. Therefore the signal covariance $S$ is diagonal in harmonic space, and is described by a one-dimensional power spectrum, assumed here as $$P(k) = P_0\,\left(1+\left(\frac{k}{k_0}\right)^2\right)^{-\gamma /2},$$
with $P_0 = 0.2, k_0 = 5, \gamma = 4$.
- $N = 0.2 \cdot \mathbb{1}$.
- Number of data points $N_{pix} = 512$.
- reconstruction in harmonic space.
- Response operator:
$$R = FFT_{\text{harmonic} \rightarrow \text{position}}$$
%% Cell type:code id: tags:
``` python
N_pixels = 512 # Number of pixels
def pow_spec(k):
P0, k0, gamma = [.2, 5, 4]
return P0 / ((1. + (k/k0)**2)**(gamma / 2))
```
%% Cell type:markdown id: tags:
### Implementation
%% Cell type:markdown id: tags:
#### Import Modules
%% Cell type:code id: tags:
``` python
%matplotlib inline
import numpy as np
import nifty8 as ift
import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 100
plt.style.use("seaborn-notebook")
```
%% Cell type:markdown id: tags:
#### Implement Propagator
%% Cell type:code id: tags:
``` python
def Curvature(R, N, Sh):
IC = ift.GradientNormController(iteration_limit=50000,
tol_abs_gradnorm=0.1)
N = ift.SamplingDtypeSetter(N, np.float64)
Sh = ift.SamplingDtypeSetter(Sh, np.float64)
# WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy
# helper methods.
return ift.WienerFilterCurvature(R,N,Sh,iteration_controller=IC,iteration_controller_sampling=IC)
```
%% Cell type:markdown id: tags:
#### Conjugate Gradient Preconditioning
- $D$ is defined via:
$$D^{-1} = \mathcal S_h^{-1} + R^\dagger N^{-1} R.$$
In the end, we want to apply $D$ to $j$, i.e. we need the inverse action of $D^{-1}$. This is done numerically (algorithm: *Conjugate Gradient*).
<!--
- One can define the *condition number* of a non-singular and normal matrix $A$:
$$\kappa (A) := \frac{|\lambda_{\text{max}}|}{|\lambda_{\text{min}}|},$$
where $\lambda_{\text{max}}$ and $\lambda_{\text{min}}$ are the largest and smallest eigenvalue of $A$, respectively.
- The larger $\kappa$ the slower Conjugate Gradient.
- By default, conjugate gradient solves: $D^{-1} m = j$ for $m$, where $D^{-1}$ can be badly conditioned. If one knows a non-singular matrix $T$ for which $TD^{-1}$ is better conditioned, one can solve the equivalent problem:
$$\tilde A m = \tilde j,$$
where $\tilde A = T D^{-1}$ and $\tilde j = Tj$.
- In our case $S^{-1}$ is responsible for the bad conditioning of $D$ depending on the chosen power spectrum. Thus, we choose
$$T = \mathcal F^\dagger S_h^{-1} \mathcal F.$$
-->
%% Cell type:markdown id: tags:
#### Generate Mock data
- Generate a field $s$ and $n$ with given covariances.
- Calculate $d$.
%% Cell type:code id: tags:
``` python
s_space = ift.RGSpace(N_pixels)
h_space = s_space.get_default_codomain()
HT = ift.HarmonicTransformOperator(h_space, target=s_space)
# Operators
Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)
R = HT # @ ift.create_harmonic_smoothing_operator((h_space,), 0, 0.02)
# Fields and data
sh = Sh.draw_sample_with_dtype(dtype=np.float64)
noiseless_data=R(sh)
noise_amplitude = np.sqrt(0.2)
N = ift.ScalingOperator(s_space, noise_amplitude**2)
n = ift.Field.from_random(domain=s_space, random_type='normal',
std=noise_amplitude, mean=0)
d = noiseless_data + n
j = R.adjoint_times(N.inverse_times(d))
curv = Curvature(R=R, N=N, Sh=Sh)
D = curv.inverse
```
%% Cell type:markdown id: tags:
#### Run Wiener Filter
%% Cell type:code id: tags:
``` python
m = D(j)
```
%% Cell type:markdown id: tags:
#### Results
%% Cell type:code id: tags:
``` python
# Get signal data and reconstruction data
s_data = HT(sh).val
m_data = HT(m).val
d_data = d.val
plt.plot(s_data, 'r', label="Signal", linewidth=2)
plt.plot(d_data, 'k.', label="Data")
plt.plot(m_data, 'k', label="Reconstruction",linewidth=2)
plt.title("Reconstruction")
plt.legend()
plt.show()
```
%% Cell type:code id: tags:
``` python
plt.plot(s_data - s_data, 'r', label="Signal", linewidth=2)
plt.plot(d_data - s_data, 'k.', label="Data")
plt.plot(m_data - s_data, 'k', label="Reconstruction",linewidth=2)
plt.axhspan(-noise_amplitude,noise_amplitude, facecolor='0.9', alpha=.5)
plt.title("Residuals")
plt.legend()
plt.show()
```
%% Cell type:markdown id: tags:
#### Power Spectrum
%% Cell type:code id: tags:
``` python
s_power_data = ift.power_analyze(sh).val
m_power_data = ift.power_analyze(m).val
plt.loglog()
plt.xlim(1, int(N_pixels/2))
ymin = min(m_power_data)
plt.ylim(ymin, 1)
xs = np.arange(1,int(N_pixels/2),.1)
plt.plot(xs, pow_spec(xs), label="True Power Spectrum", color='k',alpha=0.5)
plt.plot(s_power_data, 'r', label="Signal")
plt.plot(m_power_data, 'k', label="Reconstruction")
plt.axhline(noise_amplitude**2 / N_pixels, color="k", linestyle='--', label="Noise level", alpha=.5)
plt.axhspan(noise_amplitude**2 / N_pixels, ymin, facecolor='0.9', alpha=.5)
plt.title("Power Spectrum")
plt.legend()
plt.show()
```
%% Cell type:markdown id: tags:
## Wiener Filter on Incomplete Data
%% Cell type:code id: tags:
``` python
# Operators
Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)
N = ift.ScalingOperator(s_space, noise_amplitude**2)
# R is defined below
# Fields
sh = Sh.draw_sample_with_dtype(dtype=np.float64)
s = HT(sh)
n = ift.Field.from_random(domain=s_space, random_type='normal',
std=noise_amplitude, mean=0)
```
%% Cell type:markdown id: tags:
### Partially Lose Data
%% Cell type:code id: tags:
``` python
l = int(N_pixels * 0.2)
h = int(N_pixels * 0.2 * 2)
mask = np.full(s_space.shape, 1.)
mask[l:h] = 0
mask = ift.Field.from_raw(s_space, mask)
R = ift.DiagonalOperator(mask)(HT)
n = n.val_rw()
n[l:h] = 0
n = ift.Field.from_raw(s_space, n)
d = R(sh) + n
```
%% Cell type:code id: tags:
``` python
curv = Curvature(R=R, N=N, Sh=Sh)
D = curv.inverse
j = R.adjoint_times(N.inverse_times(d))
m = D(j)
```
%% Cell type:markdown id: tags:
### Compute Uncertainty
%% Cell type:code id: tags:
``` python
m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 200, np.float64)
```
%% Cell type:markdown id: tags:
### Get data
%% Cell type:code id: tags:
``` python
# Get signal data and reconstruction data
s_data = s.val
m_data = HT(m).val
m_var_data = m_var.val
uncertainty = np.sqrt(m_var_data)
d_data = d.val_rw()
# Set lost data to NaN for proper plotting
d_data[d_data == 0] = np.nan
```
%% Cell type:code id: tags:
``` python
plt.axvspan(l, h, facecolor='0.8',alpha=0.5)
plt.fill_between(range(N_pixels), m_data - uncertainty, m_data + uncertainty, facecolor='0.5', alpha=0.5)
plt.plot(s_data, 'r', label="Signal", alpha=1, linewidth=2)
plt.plot(d_data, 'k.', label="Data")
plt.plot(m_data, 'k', label="Reconstruction", linewidth=2)
plt.title("Reconstruction of incomplete data")
plt.legend()
```
%% Cell type:markdown id: tags:
## Wiener filter on two-dimensional field
%% Cell type:code id: tags:
``` python
N_pixels = 256 # Number of pixels
sigma2 = 2. # Noise variance
def pow_spec(k):
P0, k0, gamma = [.2, 2, 4]
return P0 * (1. + (k/k0)**2)**(-gamma/2)
s_space = ift.RGSpace([N_pixels, N_pixels])
```
%% Cell type:code id: tags:
``` python
h_space = s_space.get_default_codomain()
HT = ift.HarmonicTransformOperator(h_space,s_space)
# Operators
Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)
N = ift.ScalingOperator(s_space, sigma2)
# Fields and data
sh = Sh.draw_sample_with_dtype(dtype=np.float64)
n = ift.Field.from_random(domain=s_space, random_type='normal',
std=np.sqrt(sigma2), mean=0)
# Lose some data
l = int(N_pixels * 0.33)
h = int(N_pixels * 0.33 * 2)
mask = np.full(s_space.shape, 1.)
mask[l:h,l:h] = 0.
mask = ift.Field.from_raw(s_space, mask)
R = ift.DiagonalOperator(mask)(HT)
n = n.val_rw()
n[l:h, l:h] = 0
n = ift.Field.from_raw(s_space, n)
curv = Curvature(R=R, N=N, Sh=Sh)
D = curv.inverse
d = R(sh) + n
j = R.adjoint_times(N.inverse_times(d))
# Run Wiener filter
m = D(j)
# Uncertainty
m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 20, np.float64)
# Get data
s_data = HT(sh).val
m_data = HT(m).val
m_var_data = m_var.val
d_data = d.val
uncertainty = np.sqrt(np.abs(m_var_data))
```
%% Cell type:code id: tags:
``` python
cmap = ['magma', 'inferno', 'plasma', 'viridis'][1]
mi = np.min(s_data)
ma = np.max(s_data)
fig, axes = plt.subplots(1, 2)
data = [s_data, d_data]
caption = ["Signal", "Data"]
for ax in axes.flat:
im = ax.imshow(data.pop(0), interpolation='nearest', cmap=cmap, vmin=mi,
vmax=ma)
ax.set_title(caption.pop(0))
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)
```
%% Cell type:code id: tags:
``` python
mi = np.min(s_data)
ma = np.max(s_data)
fig, axes = plt.subplots(3, 2, figsize=(10, 15))
sample = HT(curv.draw_sample(from_inverse=True)+m).val
post_mean = (m_mean + HT(m)).val
data = [s_data, m_data, post_mean, sample, s_data - m_data, uncertainty]
caption = ["Signal", "Reconstruction", "Posterior mean", "Sample", "Residuals", "Uncertainty Map"]
for ax in axes.flat:
im = ax.imshow(data.pop(0), interpolation='nearest', cmap=cmap, vmin=mi, vmax=ma)
ax.set_title(caption.pop(0))
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)
```
%% Cell type:markdown id: tags:
### Is the uncertainty map reliable?
%% Cell type:code id: tags:
``` python
precise = (np.abs(s_data-m_data) < uncertainty)
print("Error within uncertainty map bounds: " + str(np.sum(precise) * 100 / N_pixels**2) + "%")
plt.imshow(precise.astype(float), cmap="brg")
plt.colorbar()
```
......
......@@ -11,21 +11,18 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.inversion_enabler import InversionEnabler
from ..operators.sampling_enabler import SamplingDtypeSetter, SamplingEnabler
from ..operators.sampling_enabler import SamplingEnabler
from ..operators.sandwich_operator import SandwichOperator
def WienerFilterCurvature(R, N, S, iteration_controller=None,
iteration_controller_sampling=None,
data_sampling_dtype=np.float64,
prior_sampling_dtype=np.float64):
iteration_controller_sampling=None):
"""The curvature of the WienerFilterEnergy.
This operator implements the second derivative of the
......@@ -39,32 +36,27 @@ def WienerFilterCurvature(R, N, S, iteration_controller=None,
The response operator of the Wiener filter measurement.
N : EndomorphicOperator
The noise covariance.
S : DiagonalOperator
The prior signal covariance
S : EndomorphicOperator
The prior signal covariance.
iteration_controller : IterationController
The iteration controller to use during numerical inversion via
ConjugateGradient.
iteration_controller_sampling : IterationController
The iteration controller to use for sampling.
data_sampling_dtype : numpy.dtype or dict of numpy.dtype
Data type used for sampling from likelihood. Conincides with the data
type of the data used in the inference problem. Default is float64.
prior_sampling_dtype : numpy.dtype or dict of numpy.dtype
Data type used for sampling from likelihood. Coincides with the data
type of the parameters of the forward model used for the inference
problem. Default is float64.
Note
----
If samples shall be drawn from this operator, `N` and `S` have to implement
`draw_sample()`.
"""
Ninv = N.inverse
if not isinstance(N, EndomorphicOperator):
raise TypeError
if not isinstance(S, EndomorphicOperator):
raise TypeError
M = SandwichOperator.make(R, N.inverse)
Sinv = S.inverse
if data_sampling_dtype is not None:
Ninv = SamplingDtypeSetter(Ninv, data_sampling_dtype)
if prior_sampling_dtype is not None:
Sinv = SamplingDtypeSetter(Sinv, data_sampling_dtype)
M = SandwichOperator.make(R, Ninv)
if iteration_controller_sampling is not None:
op = SamplingEnabler(M, Sinv, iteration_controller_sampling,
Sinv)
op = SamplingEnabler(M, Sinv, iteration_controller_sampling, Sinv)
else:
op = M + Sinv
op = InversionEnabler(op, iteration_controller, Sinv)
return op
return InversionEnabler(op, iteration_controller, Sinv)
......@@ -79,16 +79,14 @@ def test_WF_curvature(space):
required_result = ift.full(space, 1.)
s = ift.Field.from_random(domain=space, random_type='uniform') + 0.5
S = ift.DiagonalOperator(s)
S = ift.SamplingDtypeSetter(ift.DiagonalOperator(s), np.float64)
r = ift.Field.from_random(domain=space, random_type='uniform')
R = ift.DiagonalOperator(r)
n = ift.Field.from_random(domain=space, random_type='uniform') + 0.5
N = ift.DiagonalOperator(n)
N = ift.SamplingDtypeSetter(ift.DiagonalOperator(n), np.float64)
all_diag = 1./s + r**2/n
curv = ift.WienerFilterCurvature(R, N, S, iteration_controller=IC,
iteration_controller_sampling=IC,
data_sampling_dtype=np.float64,
prior_sampling_dtype=np.float64)
iteration_controller_sampling=IC)
m = curv.inverse(required_result)
assert_allclose(
m.val,
......@@ -101,13 +99,11 @@ def test_WF_curvature(space):
if len(space.shape) == 1:
R = ift.ValueInserter(space, [0])
n = ift.from_random(R.domain, 'uniform') + 0.5
N = ift.DiagonalOperator(n)
N = ift.SamplingDtypeSetter(ift.DiagonalOperator(n), np.float64)
all_diag = 1./s + R(1/n)
curv = ift.WienerFilterCurvature(R.adjoint, N, S,
iteration_controller=IC,
iteration_controller_sampling=IC,
data_sampling_dtype=np.float64,
prior_sampling_dtype=np.float64)
iteration_controller_sampling=IC)
m = curv.inverse(required_result)
assert_allclose(
m.val,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment