Commit c3ed466f authored by Martin Reinecke's avatar Martin Reinecke
Browse files

no more chains

parent 369c6e7c
......@@ -429,7 +429,7 @@
"mask[l:h] = 0\n",
"mask = ift.Field.from_global_data(s_space, mask)\n",
"\n",
"R = ift.DiagonalOperator(mask).chain(HT)\n",
"R = ift.DiagonalOperator(mask)(HT)\n",
"n = n.to_global_data_rw()\n",
"n[l:h] = 0\n",
"n = ift.Field.from_global_data(s_space, n)\n",
......@@ -585,7 +585,7 @@
"mask[l:h,l:h] = 0.\n",
"mask = ift.Field.from_global_data(s_space, mask)\n",
"\n",
"R = ift.DiagonalOperator(mask).chain(HT)\n",
"R = ift.DiagonalOperator(mask)(HT)\n",
"n = n.to_global_data_rw()\n",
"n[l:h, l:h] = 0\n",
"n = ift.Field.from_global_data(s_space, n)\n",
......
%% Cell type:markdown id: tags:
# A NIFTy demonstration
%% Cell type:markdown id: tags:
## IFT: Big Picture
IFT starting point:
$$d = Rs+n$$
Typically, $s$ is a continuous field and $d$ a discrete data vector. In particular, $R$ is generally not invertible.
IFT aims at **inverting** the above uninvertible problem in the **best possible way** using Bayesian statistics.
## NIFTy
NIFTy (Numerical Information Field Theory) is a Python framework in which IFT problems can be tackled easily.
Main Interfaces:
- **Spaces**: Cartesian, 2-Spheres (Healpix, Gauss-Legendre) and their respective harmonic spaces.
- **Fields**: Defined on spaces.
- **Operators**: Acting on fields.
%% Cell type:markdown id: tags:
## Wiener Filter: Formulae
### Assumptions
- $d=Rs+n$, $R$ linear operator.
- $\mathcal P (s) = \mathcal G (s,S)$, $\mathcal P (n) = \mathcal G (n,N)$ where $S, N$ are positive definite matrices.
### Posterior
The Posterior is given by:
$$\mathcal P (s|d) \propto P(s,d) = \mathcal G(d-Rs,N) \,\mathcal G(s,S) \propto \mathcal G (s-m,D) $$
where
$$\begin{align}
m &= Dj \\
D^{-1}&= (S^{-1} +R^\dagger N^{-1} R )\\
j &= R^\dagger N^{-1} d
\end{align}$$
Let us implement this in NIFTy!
%% Cell type:markdown id: tags:
## Wiener Filter: Example
- We assume statistical homogeneity and isotropy. Therefore the signal covariance $S$ is diagonal in harmonic space, and is described by a one-dimensional power spectrum, assumed here as $$P(k) = P_0\,\left(1+\left(\frac{k}{k_0}\right)^2\right)^{-\gamma /2},$$
with $P_0 = 0.2, k_0 = 5, \gamma = 4$.
- $N = 0.2 \cdot \mathbb{1}$.
- Number of data points $N_{pix} = 512$.
- reconstruction in harmonic space.
- Response operator:
$$R = FFT_{\text{harmonic} \rightarrow \text{position}}$$
%% Cell type:code id: tags:
``` python
N_pixels = 512  # Number of pixels

def pow_spec(k):
    """Power-law spectrum P(k) = p0 / (1 + (k/k0)^2)^(g/2)."""
    # Amplitude, knee scale and slope of the spectrum.
    p0, k0, g = .2, 5, 4
    return p0 / ((1. + (k/k0)**2)**(g / 2))
```
%% Cell type:markdown id: tags:
## Wiener Filter: Implementation
%% Cell type:markdown id: tags:
### Import Modules
%% Cell type:code id: tags:
``` python
import numpy as np
np.random.seed(40)
import nifty5 as ift
import matplotlib.pyplot as plt
%matplotlib inline
```
%% Cell type:markdown id: tags:
### Implement Propagator
%% Cell type:code id: tags:
``` python
def Curvature(R, N, Sh):
    """Build the Wiener-filter curvature D^{-1} = R^† N^{-1} R + Sh^{-1}.

    The returned WienerFilterCurvature object also provides handy helper
    methods (inversion via conjugate gradient, sampling).
    """
    controller = ift.GradientNormController(iteration_limit=50000,
                                            tol_abs_gradnorm=0.1)
    return ift.WienerFilterCurvature(
        R, N, Sh,
        iteration_controller=controller,
        iteration_controller_sampling=controller)
```
%% Cell type:markdown id: tags:
### Conjugate Gradient Preconditioning
- $D$ is defined via:
$$D^{-1} = \mathcal S_h^{-1} + R^\dagger N^{-1} R.$$
In the end, we want to apply $D$ to $j$, i.e. we need the inverse action of $D^{-1}$. This is done numerically (algorithm: *Conjugate Gradient*).
<!--
- One can define the *condition number* of a non-singular and normal matrix $A$:
$$\kappa (A) := \frac{|\lambda_{\text{max}}|}{|\lambda_{\text{min}}|},$$
where $\lambda_{\text{max}}$ and $\lambda_{\text{min}}$ are the largest and smallest eigenvalue of $A$, respectively.
- The larger $\kappa$ the slower Conjugate Gradient.
- By default, conjugate gradient solves: $D^{-1} m = j$ for $m$, where $D^{-1}$ can be badly conditioned. If one knows a non-singular matrix $T$ for which $TD^{-1}$ is better conditioned, one can solve the equivalent problem:
$$\tilde A m = \tilde j,$$
where $\tilde A = T D^{-1}$ and $\tilde j = Tj$.
- In our case $S^{-1}$ is responsible for the bad conditioning of $D$ depending on the chosen power spectrum. Thus, we choose
$$T = \mathcal F^\dagger S_h^{-1} \mathcal F.$$
-->
%% Cell type:markdown id: tags:
### Generate Mock data
- Generate a field $s$ and $n$ with given covariances.
- Calculate $d$.
%% Cell type:code id: tags:
``` python
# Signal space (regular 1D grid) and its harmonic partner space.
s_space = ift.RGSpace(N_pixels)
h_space = s_space.get_default_codomain()
# HT maps fields from harmonic space to position space.
HT = ift.HarmonicTransformOperator(h_space, target=s_space)
# Operators
Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)
# Response is just the harmonic transform here (smoothing variant commented out).
R = HT #*ift.create_harmonic_smoothing_operator((h_space,), 0, 0.02)
# Fields and data
sh = Sh.draw_sample()
noiseless_data=R(sh)
noise_amplitude = np.sqrt(0.2)
# Noise covariance: N = noise_variance * identity on the signal space.
N = ift.ScalingOperator(noise_amplitude**2, s_space)
n = ift.Field.from_random(domain=s_space, random_type='normal',
                          std=noise_amplitude, mean=0)
# Data d = R s + n; information source j = R^† N^{-1} d.
d = noiseless_data + n
j = R.adjoint_times(N.inverse_times(d))
# Propagator D = (Sh^{-1} + R^† N^{-1} R)^{-1}, applied via conjugate gradient.
curv = Curvature(R=R, N=N, Sh=Sh)
D = curv.inverse
```
%% Cell type:markdown id: tags:
### Run Wiener Filter
%% Cell type:code id: tags:
``` python
# Apply the propagator to the information source: posterior mean m = D j.
m = D(j)
```
%% Cell type:markdown id: tags:
### Signal Reconstruction
%% Cell type:code id: tags:
``` python
# Get signal data and reconstruction data
s_data = HT(sh).to_global_data()
m_data = HT(m).to_global_data()
d_data = d.to_global_data()
# Overlay the true signal, the noisy data and the Wiener-filter reconstruction.
plt.figure(figsize=(15,10))
plt.plot(s_data, 'r', label="Signal", linewidth=3)
plt.plot(d_data, 'k.', label="Data")
plt.plot(m_data, 'k', label="Reconstruction",linewidth=3)
plt.title("Reconstruction")
plt.legend()
plt.show()
```
%% Cell type:code id: tags:
``` python
# Residuals relative to the true signal; the shaded band marks one
# noise standard deviation around zero.
plt.figure(figsize=(15,10))
plt.plot(s_data - s_data, 'r', label="Signal", linewidth=3)
plt.plot(d_data - s_data, 'k.', label="Data")
plt.plot(m_data - s_data, 'k', label="Reconstruction",linewidth=3)
plt.axhspan(-noise_amplitude,noise_amplitude, facecolor='0.9', alpha=.5)
plt.title("Residuals")
plt.legend()
plt.show()
```
%% Cell type:markdown id: tags:
### Power Spectrum
%% Cell type:code id: tags:
``` python
# Empirical power spectra of the signal and of the reconstruction,
# compared against the true spectrum on a log-log plot.
s_power_data = ift.power_analyze(sh).to_global_data()
m_power_data = ift.power_analyze(m).to_global_data()
plt.figure(figsize=(15,10))
plt.loglog()
plt.xlim(1, int(N_pixels/2))
ymin = min(m_power_data)
plt.ylim(ymin, 1)
xs = np.arange(1,int(N_pixels/2),.1)
plt.plot(xs, pow_spec(xs), label="True Power Spectrum", color='k',alpha=0.5)
plt.plot(s_power_data, 'r', label="Signal")
plt.plot(m_power_data, 'k', label="Reconstruction")
# Flat per-mode noise power level (below it the reconstruction is noise-dominated).
plt.axhline(noise_amplitude**2 / N_pixels, color="k", linestyle='--', label="Noise level", alpha=.5)
plt.axhspan(noise_amplitude**2 / N_pixels, ymin, facecolor='0.9', alpha=.5)
plt.title("Power Spectrum")
plt.legend()
plt.show()
```
%% Cell type:markdown id: tags:
## Wiener Filter on Incomplete Data
%% Cell type:code id: tags:
``` python
# Fresh signal and noise realizations for the incomplete-data example.
# Operators
Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)
N = ift.ScalingOperator(noise_amplitude**2,s_space)
# R is defined below
# Fields
sh = Sh.draw_sample()
s = HT(sh)
n = ift.Field.from_random(domain=s_space, random_type='normal',
                          std=noise_amplitude, mean=0)
```
%% Cell type:markdown id: tags:
### Partially Lose Data
%% Cell type:code id: tags:
``` python
# Mask out the central 20% of the domain to simulate missing data.
l = int(N_pixels * 0.2)
h = int(N_pixels * 0.2 * 2)
mask = np.full(s_space.shape, 1.)
mask[l:h] = 0
mask = ift.Field.from_global_data(s_space, mask)
# Response: harmonic transform followed by the mask. Operator composition
# is expressed by calling one operator on another; the old `.chain` method
# was removed from the Operator API, so the stale
# `ift.DiagonalOperator(mask).chain(HT)` line (diff residue) is dropped.
R = ift.DiagonalOperator(mask)(HT)
# Zero the noise in the masked region as well, then form the data d = R s + n.
n = n.to_global_data_rw()
n[l:h] = 0
n = ift.Field.from_global_data(s_space, n)
d = R(sh) + n
```
%% Cell type:code id: tags:
``` python
# Rebuild propagator and information source for the masked response,
# then compute the Wiener-filter mean.
curv = Curvature(R=R, N=N, Sh=Sh)
D = curv.inverse
j = R.adjoint_times(N.inverse_times(d))
m = D(j)
```
%% Cell type:markdown id: tags:
### Compute Uncertainty
%% Cell type:code id: tags:
``` python
# Estimate posterior mean and pointwise variance in position space
# from 200 posterior samples.
m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 200)
```
%% Cell type:markdown id: tags:
### Get data
%% Cell type:code id: tags:
``` python
# Get signal data and reconstruction data
s_data = s.to_global_data()
m_data = HT(m).to_global_data()
m_var_data = m_var.to_global_data()
uncertainty = np.sqrt(m_var_data)
d_data = d.to_global_data_rw()
# Set lost data to NaN for proper plotting
# NOTE(review): this treats every exactly-zero pixel as masked; a genuine
# data value of 0.0 would be hidden too — confirm this is acceptable here.
d_data[d_data == 0] = np.nan
```
%% Cell type:code id: tags:
``` python
# Reconstruction with a one-sigma uncertainty band; the grey vertical
# span marks the masked (data-free) region.
fig = plt.figure(figsize=(15,10))
plt.axvspan(l, h, facecolor='0.8',alpha=0.5)
plt.fill_between(range(N_pixels), m_data - uncertainty, m_data + uncertainty, facecolor='0.5', alpha=0.5)
plt.plot(s_data, 'r', label="Signal", alpha=1, linewidth=3)
plt.plot(d_data, 'k.', label="Data")
plt.plot(m_data, 'k', label="Reconstruction", linewidth=3)
plt.title("Reconstruction of incomplete data")
plt.legend()
```
%% Cell type:markdown id: tags:
# 2d Example
%% Cell type:code id: tags:
``` python
N_pixels = 256  # Number of pixels per axis
sigma2 = 2.     # Noise variance

def pow_spec(k):
    """Power-law spectrum p0 * (1 + (k/k0)^2)^(-g/2) for the 2D example."""
    p0, k0, g = .2, 2, 4
    return p0 * (1. + (k/k0)**2)**(-g/2)

# Two-dimensional regular-grid signal space.
s_space = ift.RGSpace([N_pixels, N_pixels])
```
%% Cell type:code id: tags:
``` python
# Harmonic partner space and transform for the 2D grid.
h_space = s_space.get_default_codomain()
HT = ift.HarmonicTransformOperator(h_space,s_space)
# Operators
Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)
N = ift.ScalingOperator(sigma2,s_space)
# Fields and data
sh = Sh.draw_sample()
n = ift.Field.from_random(domain=s_space, random_type='normal',
                          std=np.sqrt(sigma2), mean=0)
# Lose some data: mask out a central square patch.
l = int(N_pixels * 0.33)
h = int(N_pixels * 0.33 * 2)
mask = np.full(s_space.shape, 1.)
mask[l:h,l:h] = 0.
mask = ift.Field.from_global_data(s_space, mask)
# Response: mask applied after the harmonic transform. Composition is done
# by calling one operator on another; the stale
# `ift.DiagonalOperator(mask).chain(HT)` line (diff residue from the removed
# `.chain` API) is dropped.
R = ift.DiagonalOperator(mask)(HT)
# Zero the noise inside the masked patch as well.
n = n.to_global_data_rw()
n[l:h, l:h] = 0
n = ift.Field.from_global_data(s_space, n)
curv = Curvature(R=R, N=N, Sh=Sh)
D = curv.inverse
d = R(sh) + n
j = R.adjoint_times(N.inverse_times(d))
# Run Wiener filter
m = D(j)
# Uncertainty
m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 20)
# Get data
s_data = HT(sh).to_global_data()
m_data = HT(m).to_global_data()
m_var_data = m_var.to_global_data()
d_data = d.to_global_data()
uncertainty = np.sqrt(np.abs(m_var_data))
```
%% Cell type:code id: tags:
``` python
# Side-by-side images of the 2D signal and the masked data,
# on a shared color scale taken from the signal.
cm = ['magma', 'inferno', 'plasma', 'viridis'][1]
mi = np.min(s_data)
ma = np.max(s_data)
fig, axes = plt.subplots(1, 2, figsize=(15, 7))
data = [s_data, d_data]
caption = ["Signal", "Data"]
for ax in axes.flat:
    im = ax.imshow(data.pop(0), interpolation='nearest', cmap=cm, vmin=mi,
                   vmax=ma)
    ax.set_title(caption.pop(0))
# One shared colorbar to the right of both panels.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)
```
%% Cell type:code id: tags:
``` python
# Posterior summary panels: signal, reconstruction, posterior mean,
# one posterior sample, residuals and the uncertainty map.
mi = np.min(s_data)
ma = np.max(s_data)
fig, axes = plt.subplots(3, 2, figsize=(15, 22.5))
# One posterior sample: m plus a residual drawn via curv.draw_sample(from_inverse=True)
# (presumably a draw with covariance D — confirm against the curvature API).
sample = HT(curv.draw_sample(from_inverse=True)+m).to_global_data()
post_mean = (m_mean + HT(m)).to_global_data()
data = [s_data, m_data, post_mean, sample, s_data - m_data, uncertainty]
caption = ["Signal", "Reconstruction", "Posterior mean", "Sample", "Residuals", "Uncertainty Map"]
for ax in axes.flat:
    im = ax.imshow(data.pop(0), interpolation='nearest', cmap=cm, vmin=mi, vmax=ma)
    ax.set_title(caption.pop(0))
# Shared colorbar for all six panels.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)
```
%% Cell type:markdown id: tags:
### Is the uncertainty map reliable?
%% Cell type:code id: tags:
``` python
# Fraction of pixels whose true error lies within the one-sigma
# uncertainty estimate, plus a map of where the bound holds.
precise = (np.abs(s_data-m_data) < uncertainty)
print("Error within uncertainty map bounds: " + str(np.sum(precise) * 100 / N_pixels**2) + "%")
plt.figure(figsize=(15,10))
plt.imshow(precise.astype(float), cmap="brg")
plt.colorbar()
```
%% Cell type:markdown id: tags:
# Start Coding
## NIFTy Repository + Installation guide
https://gitlab.mpcdf.mpg.de/ift/NIFTy
NIFTy v5 **more or less stable!**
......
......@@ -53,7 +53,7 @@ if __name__ == '__main__':
A = pd(a)
# Set up a sky model
sky = HT.chain(ift.makeOp(A)).positive_tanh()
sky = HT(ift.makeOp(A)).positive_tanh()
GR = ift.GeometryRemover(position_space)
# Set up instrumental response
......@@ -61,7 +61,7 @@ if __name__ == '__main__':
# Generate mock data
d_space = R.target[0]
p = R.chain(sky)
p = R(sky)
mock_position = ift.from_random('normal', harmonic_space)
pp = p(mock_position)
data = np.random.binomial(1, pp.to_global_data().astype(np.float64))
......
......@@ -78,7 +78,7 @@ if __name__ == '__main__':
GR = ift.GeometryRemover(position_space)
mask = ift.Field.from_global_data(position_space, mask)
Mask = ift.DiagonalOperator(mask)
R = GR.chain(Mask).chain(HT)
R = GR(Mask(HT))
data_space = GR.target
......@@ -93,7 +93,7 @@ if __name__ == '__main__':
# Build propagator D and information source j
j = R.adjoint_times(N.inverse_times(data))
D_inv = R.adjoint.chain(N.inverse).chain(R) + S.inverse
D_inv = R.adjoint(N.inverse(R)) + S.inverse
# Make it invertible
IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=1e-3)
D = ift.InversionEnabler(D_inv, IC, approximation=S.inverse).inverse
......@@ -112,7 +112,7 @@ if __name__ == '__main__':
title="getting_started_1")
else:
ift.plot(HT(MOCK_SIGNAL), title='Mock Signal')
ift.plot(mask_to_nan(mask, (GR.chain(Mask)).adjoint(data)),
ift.plot(mask_to_nan(mask, (GR(Mask)).adjoint(data)),
title='Data')
ift.plot(HT(m), title='Reconstruction')
ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL)), title='Residuals')
......
......@@ -70,16 +70,16 @@ if __name__ == '__main__':
A = pd(a)
# Set up a sky model
sky = ift.exp(HT.chain(ift.makeOp(A)))
sky = ift.exp(HT(ift.makeOp(A)))
M = ift.DiagonalOperator(exposure)
GR = ift.GeometryRemover(position_space)
# Set up instrumental response
R = GR.chain(M)
R = GR(M)
# Generate mock data
d_space = R.target[0]
lamb = R.chain(sky)
lamb = R(sky)
mock_position = ift.from_random('normal', domain)
data = lamb(mock_position)
data = np.random.poisson(data.to_global_data().astype(np.float64))
......
......@@ -44,8 +44,8 @@ if __name__ == '__main__':
domain = ift.MultiDomain.union(
(A.domain, ift.MultiDomain.make({'xi': harmonic_space})))
correlated_field = ht.chain(
power_distributor.chain(A)*ift.FieldAdapter(domain, "xi"))
correlated_field = ht(
power_distributor(A)*ift.FieldAdapter(domain, "xi"))
# alternatively to the block above one can do:
# correlated_field = ift.CorrelatedField(position_space, A)
......@@ -57,7 +57,7 @@ if __name__ == '__main__':
R = ift.LOSResponse(position_space, starts=LOS_starts,
ends=LOS_ends)
# build signal response model and model likelihood
signal_response = R.chain(signal)
signal_response = R(signal)
# specify noise
data_space = R.target
noise = .001
......@@ -69,7 +69,7 @@ if __name__ == '__main__':
# set up model likelihood
likelihood = ift.GaussianEnergy(
mean=data, covariance=N).chain(signal_response)
mean=data, covariance=N)(signal_response)
# set up minimization and inversion schemes
ic_cg = ift.GradientNormController(iteration_limit=10)
......
......@@ -97,7 +97,7 @@ d = ift.from_global_data(d_space, y)
N = ift.DiagonalOperator(ift.from_global_data(d_space, var))
IC = ift.GradientNormController(tol_abs_gradnorm=1e-8)
likelihood = ift.GaussianEnergy(d, N).chain(R)
likelihood = ift.GaussianEnergy(d, N)(R)
H = ift.Hamiltonian(likelihood, IC)
H = ift.EnergyAdapter(params, H)
H = H.make_invertible(IC)
......
......@@ -40,7 +40,7 @@ class Hamiltonian(Operator):
def target(self):
return DomainTuple.scalar_domain()
def __call__(self, x):
def apply(self, x):
if self._ic_samp is None or not isinstance(x, Linearization):
return self._lh(x) + self._prior(x)
else:
......
......@@ -42,6 +42,6 @@ class SampledKullbachLeiblerDivergence(Operator):
def target(self):
return DomainTuple.scalar_domain()
def __call__(self, x):
def apply(self, x):
return (my_sum(map(lambda v: self._h(x+v), self._res_samples)) *
(1./len(self._res_samples)))
......@@ -130,7 +130,7 @@ class AmplitudeModel(Operator):
cepstrum = create_cepstrum_amplitude_field(dof_space, kern)
ceps = makeOp(sqrt(cepstrum))
self._smooth_op = sym.chain(qht).chain(ceps)
self._smooth_op = sym(qht(ceps))
self._keys = tuple(keys)
@property
......@@ -141,7 +141,7 @@ class AmplitudeModel(Operator):
def target(self):
return self._target
def __call__(self, x):
def apply(self, x):
smooth_spec = self._smooth_op(x[self._keys[0]])
phi = x[self._keys[1]] + self._norm_phi_mean
linear_spec = self._slope(phi)
......
......@@ -39,7 +39,7 @@ class BernoulliEnergy(Operator):
def target(self):
return DomainTuple.scalar_domain()
def __call__(self, x):
def apply(self, x):
x = self._p(x)
v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d)
if not isinstance(x, Linearization):
......
......@@ -58,7 +58,7 @@ class CorrelatedField(Operator):
def target(self):
return self._ht.target
def __call__(self, x):
def apply(self, x):
A = self._power_distributor(self._amplitude_model(x))
correlated_field_h = A * x["xi"]
correlated_field = self._ht(correlated_field_h)
......
......@@ -55,7 +55,7 @@ class GaussianEnergy(Operator):
def target(self):
return DomainTuple.scalar_domain()
def __call__(self, x):
def apply(self, x):
residual = x if self._mean is None else x-self._mean
icovres = residual if self._icov is None else self._icov(residual)
res = .5*residual.vdot(icovres)
......
......@@ -41,7 +41,7 @@ class PoissonianEnergy(Operator):
def target(self):
return DomainTuple.scalar_domain()
def __call__(self, x):
def apply(self, x):
x = self._op(x)
res = x.sum() - x.log().vdot(self._d)
if not isinstance(x, Linearization):
......
......@@ -46,8 +46,8 @@ class Linearization(object):
def __neg__(self):
return Linearization(
-self._val, self._jac.chain(-1),
None if self._metric is None else self._metric.chain(-1))
-self._val, self._jac*(-1),
None if self._metric is None else self._metric*(-1))
def __add__(self, other):
if isinstance(other, Linearization):
......@@ -77,24 +77,24 @@ class Linearization(object):
d2 = makeOp(other._val)
return Linearization(
self._val*other._val,
d2.chain(self._jac) + d1.chain(other._jac))
d2(self._jac) + d1(other._jac))
if isinstance(other, (int, float, complex)):
# if other == 0:
# return ...
met = None if self._metric is None else self._metric.chain(other)
return Linearization(self._val*other, self._jac.chain(other), met)
met = None if self._metric is None else self._metric(other)
return Linearization(self._val*other, self._jac(other), met)
if isinstance(other, (Field, MultiField)):
d2 = makeOp(other)
return Linearization(self._val*other, d2.chain(self._jac))
return Linearization(self._val*other, d2(self._jac))
raise TypeError
def __rmul__(self, other):
from .sugar import makeOp
if isinstance(other, (int, float, complex)):
return Linearization(self._val*other, self._jac.chain(other))
return Linearization(self._val*other, self._jac(other))
if isinstance(other, (Field, MultiField)):
d1 = makeOp(other)
return Linearization(self._val*other, d1.chain(self._jac))
return Linearization(self._val*other, d1(self._jac))
def vdot(self, other):
from .domain_tuple import DomainTuple
......@@ -102,11 +102,11 @@ class Linearization(object):
if isinstance(other, (Field, MultiField)):
return Linearization(
Field(DomainTuple.scalar_domain(),self._val.vdot(other)),
VdotOperator(other).chain(self._jac))
VdotOperator(other)(self._jac))
return Linearization(
Field(DomainTuple.scalar_domain(),self._val.vdot(other._val)),
VdotOperator(self._val).chain(other._jac) +
VdotOperator(other._val).chain(self._jac))
VdotOperator(self._val)(other._jac) +
VdotOperator(other._val)(self._jac))
def sum(self):
from .domain_tuple import DomainTuple
......@@ -114,24 +114,24 @@ class Linearization(object):
from .sugar import full
return Linearization(
Field(DomainTuple.scalar_domain(), self._val.sum()),
SumReductionOperator(self._jac.target).chain(self._jac))
SumReductionOperator(self._jac.target)(self._jac))
def exp(self):
tmp = self._val.exp()
return Linearization(tmp, makeOp(tmp).chain(self._jac))
return Linearization(tmp, makeOp(tmp)(self._jac))
def log(self):
tmp = self._val.log()
return Linearization(tmp, makeOp(1./self._val).chain(self._jac))
return Linearization(tmp, makeOp(1./self._val)(self._jac))
def tanh(self):
tmp = self._val.tanh()
return Linearization(tmp, makeOp(1.-tmp**2).chain(self._jac))
return Linearization(tmp, makeOp(1.-tmp**2)(self._jac))
def positive_tanh(self):
tmp = self._val.tanh()
tmp2 = 0.5*(1.+tmp)
return Linearization(tmp2, makeOp(0.5*(1.-tmp**2)).chain(self._jac))
return Linearization(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac))
def add_metric(self, metric):
return Linearization(self._val, self._jac, metric)
......
......@@ -68,7 +68,7 @@ class BlockDiagonalOperator(EndomorphicOperator):
def _combine_chain(self, op):
if self._domain is not op._domain:
raise ValueError("domain mismatch")
res = tuple(v1.chain(v2) for v1, v2 in zip(self._ops, op._ops))
res = tuple(v1(v2) for v1, v2 in zip(self._ops, op._ops))
return BlockDiagonalOperator(self._domain, res)
def _combine_sum(self, op, selfneg, opneg):
......
......@@ -67,4 +67,4 @@ def HarmonicSmoothingOperator(domain, sigma, space=None):
ddom = list(domain)
ddom[space] = codomain
diag = DiagonalOperator(kernel, ddom, space)
return Hartley.inverse.chain(diag).chain(Hartley)
return Hartley.inverse(diag(Hartley))
......@@ -142,9 +142,6 @@ class LinearOperator(Operator):
from .chain_operator import ChainOperator
return ChainOperator.make([self, other2])
def chain(self, other):
return self.__matmul__(other)
def __rmatmul__(self, other):
if np.isscalar(other) and other == 1.:
return self
......@@ -213,10 +210,14 @@ class LinearOperator(Operator):
def __call__(self, x):
"""Same as :meth:`times`"""
from ..field import Field
from ..multi.multi_field import MultiField
if isinstance(x, (Field, MultiField)):
return self.apply(x, self.TIMES)
from ..linearization import Linearization
if isinstance(x, Linearization):
return Linearization(self(x._val), self.chain(x._jac))
return self.apply(x, self.TIMES)
return Linearization(self(x._val), self(x._jac))
return self.__matmul__(x)
def times(self, x):
""" Applies the Operator to a given Field.
......
......@@ -33,26 +33,13 @@ class Operator(NiftyMetaBase()):
return NotImplemented
return _OpProd.make((self, x))
def chain(self, x):
res = self.__matmul__(x)
if res == NotImplemented:
raise TypeError("operator expected")
return res
def apply(self, x):
raise NotImplementedError
def __call__(self, x):
"""Returns transformed x
Parameters
----------
x : Linearization
input
Returns
-------
Linearization
output
"""
raise NotImplementedError
if isinstance(x, Operator):
return _OpChain.make((self, x))
return self.apply(x)
for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]:
......@@ -78,7 +65,7 @@ class _FunctionApplier(Operator):
def target(self):
return self._domain
def __call__(self, x):
def apply(self, x):
return getattr(x, self._funcname)()
......@@ -117,7 +104,7 @@ class _OpChain(_CombinedOperator):
def target(self):
return self._ops[0].target
def __call__(self, x):
def apply(self, x):
for op in reversed(self._ops):
x = op(x)
return x
......@@ -135,7 +122,7 @@ class _OpProd(_CombinedOperator):
def target(self):
return self._ops[0].target
def __call__(self, x):
def apply(self, x):
from ..utilities import my_product
return my_product(map(lambda op: op(x), self._ops))
......@@ -154,7 +141,7 @@ class _OpSum(_CombinedOperator):
def target(self):
return self._target
def __call__(self, x):
def apply(self, x):
raise NotImplementedError
......@@ -193,7 +180,7 @@ class QuadraticFormOperator(Operator):
def target(self):
return self._target
def __call__(self, x):
def apply(self, x):
if isinstance(x, Linearization):
jac = self._op(x)
val = Field(self._target, 0.5 * x.vdot(jac))
......
......@@ -56,9 +56,9 @@ class SandwichOperator(EndomorphicOperator):
raise TypeError("cheese must be a linear operator")
if cheese is None:
cheese = ScalingOperator(1., bun.target)
op = bun.adjoint.chain(bun)
op = bun.adjoint(bun)
else:
op = bun.adjoint.chain(cheese).chain(bun)
op = bun.adjoint(cheese(bun))
# if our sandwich is diagonal, we can return immediately
if isinstance(op, (ScalingOperator, DiagonalOperator)):
......
......@@ -54,4 +54,4 @@ def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None):
if strength == 0.:
return ScalingOperator(0., domain)
laplace = LaplaceOperator(domain, logarithmic=logarithmic, space=space)
return (strength**2)*laplace.adjoint.chain(laplace)
return (strength**2)*laplace.adjoint(laplace)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment