Commit 85b3b288 authored by Martin Reinecke's avatar Martin Reinecke

Operators can no longer be chained by '*'; you need to use '.chain()' or '@' (in Python3)

parent b44405a4
%% Cell type:markdown id: tags:
# A NIFTy demonstration
%% Cell type:markdown id: tags:
## IFT: Big Picture
IFT starting point:
$$d = Rs+n$$
Typically, $s$ is a continuous field and $d$ a discrete data vector. In particular, $R$ is generally not invertible.
IFT aims at **inverting** the above non-invertible problem in the **best possible way** using Bayesian statistics.
## NIFTy
NIFTy (Numerical Information Field Theory) is a Python framework in which IFT problems can be tackled easily.
Main Interfaces:
- **Spaces**: Cartesian, 2-Spheres (Healpix, Gauss-Legendre) and their respective harmonic spaces.
- **Fields**: Defined on spaces.
- **Operators**: Acting on fields.
%% Cell type:markdown id: tags:
## Wiener Filter: Formulae
### Assumptions
- $d=Rs+n$, $R$ linear operator.
- $\mathcal P (s) = \mathcal G (s,S)$, $\mathcal P (n) = \mathcal G (n,N)$ where $S, N$ are positive definite matrices.
### Posterior
The posterior is given by:
$$\mathcal P (s|d) \propto \mathcal P(s,d) = \mathcal G(d-Rs,N) \,\mathcal G(s,S) \propto \mathcal G (s-m,D),$$
where
$$\begin{align}
m &= Dj \\
D^{-1}&= (S^{-1} +R^\dagger N^{-1} R )\\
j &= R^\dagger N^{-1} d
\end{align}$$
Let us implement this in NIFTy!
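Before turning to the NIFTy implementation, the formulae above can be sanity-checked on a tiny dense toy problem (all sizes and values below are illustrative and not part of the demo):

``` python
import numpy as np

rng = np.random.default_rng(0)
R = np.array([[1., 0., 0.],
              [0., 1., 1.]])          # non-invertible response: 3 signal pixels -> 2 data points
S = np.diag([2., 1., .5])             # signal covariance
N = 0.2*np.eye(2)                     # noise covariance
s = rng.multivariate_normal(np.zeros(3), S)
d = R @ s + rng.multivariate_normal(np.zeros(2), N)

# Posterior mean m = D j with D^{-1} = S^{-1} + R^T N^{-1} R, j = R^T N^{-1} d
D = np.linalg.inv(np.linalg.inv(S) + R.T @ np.linalg.inv(N) @ R)
j = R.T @ np.linalg.inv(N) @ d
m = D @ j                             # Wiener filter estimate of s
```

In NIFTy the same algebra is expressed with operators and fields instead of dense matrices, so it scales to many pixels.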
%% Cell type:markdown id: tags:
## Wiener Filter: Example
- We assume statistical homogeneity and isotropy. Therefore the signal covariance $S$ is diagonal in harmonic space, and is described by a one-dimensional power spectrum, assumed here as $$P(k) = P_0\,\left(1+\left(\frac{k}{k_0}\right)^2\right)^{-\gamma /2},$$
with $P_0 = 0.2, k_0 = 5, \gamma = 4$.
- $N = 0.2 \cdot \mathbb{1}$.
- Number of data points $N_{pix} = 512$.
- Reconstruction in harmonic space.
- Response operator:
$$R = FFT_{\text{harmonic} \rightarrow \text{position}}$$
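For a unitary transform such as an orthonormalized FFT, the adjoint needed in $R^\dagger N^{-1}d$ coincides with the inverse. A quick numpy check (plain FFT as an illustrative stand-in for NIFTy's Hartley-based transform):

``` python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=16)
y = rng.normal(size=16) + 1j*rng.normal(size=16)

F = lambda v: np.fft.fft(v, norm="ortho")       # unitary FFT
F_adj = lambda v: np.fft.ifft(v, norm="ortho")  # adjoint = inverse for unitary F

# <F x, y> == <x, F^dagger y> up to rounding
assert np.allclose(np.vdot(F(x), y), np.vdot(x, F_adj(y)))
```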
%% Cell type:code id: tags:
``` python
N_pixels = 512  # Number of pixels


def pow_spec(k):
    P0, k0, gamma = [.2, 5, 4]
    return P0 / ((1. + (k/k0)**2)**(gamma / 2))
```
%% Cell type:markdown id: tags:
## Wiener Filter: Implementation
%% Cell type:markdown id: tags:
### Import Modules
%% Cell type:code id: tags:
``` python
import numpy as np
np.random.seed(40)
import nifty5 as ift
import matplotlib.pyplot as plt
%matplotlib inline
```
%% Cell type:markdown id: tags:
### Implement Propagator
%% Cell type:code id: tags:
``` python
def Curvature(R, N, Sh):
    IC = ift.GradientNormController(iteration_limit=50000,
                                    tol_abs_gradnorm=0.1)
    # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus
    # some handy helper methods.
    return ift.WienerFilterCurvature(R, N, Sh, iteration_controller=IC,
                                     iteration_controller_sampling=IC)
```
%% Cell type:markdown id: tags:
### Conjugate Gradient Preconditioning
- $D$ is defined via:
$$D^{-1} = \mathcal S_h^{-1} + R^\dagger N^{-1} R.$$
In the end, we want to apply $D$ to $j$, i.e. we need the inverse action of $D^{-1}$. This is done numerically (algorithm: *Conjugate Gradient*).
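A minimal dense sketch of the algorithm (hand-rolled for illustration, not NIFTy's implementation, which works on operators and fields):

``` python
import numpy as np

def conjugate_gradient(A, b, tol=1e-10, maxiter=1000):
    """Solve A x = b for a symmetric positive definite matrix A."""
    x = np.zeros_like(b)
    r = b - A @ x                 # residual
    p = r.copy()                  # search direction
    rs = r @ r
    for _ in range(maxiter):
        Ap = A @ p
        alpha = rs / (p @ Ap)     # step size minimizing along p
        x += alpha*p
        r -= alpha*Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new/rs)*p     # next direction, A-conjugate to the previous ones
        rs = rs_new
    return x

A = np.array([[4., 1.], [1., 3.]])
b = np.array([1., 2.])
x = conjugate_gradient(A, b)
```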
<!--
- One can define the *condition number* of a non-singular and normal matrix $A$:
$$\kappa (A) := \frac{|\lambda_{\text{max}}|}{|\lambda_{\text{min}}|},$$
where $\lambda_{\text{max}}$ and $\lambda_{\text{min}}$ are the largest and smallest eigenvalue of $A$, respectively.
- The larger $\kappa$ the slower Conjugate Gradient.
- By default, conjugate gradient solves: $D^{-1} m = j$ for $m$, where $D^{-1}$ can be badly conditioned. If one knows a non-singular matrix $T$ for which $TD^{-1}$ is better conditioned, one can solve the equivalent problem:
$$\tilde A m = \tilde j,$$
where $\tilde A = T D^{-1}$ and $\tilde j = Tj$.
- In our case $S^{-1}$ is responsible for the bad conditioning of $D$ depending on the chosen power spectrum. Thus, we choose
$$T = \mathcal F^\dagger S_h^{-1} \mathcal F.$$
-->
%% Cell type:markdown id: tags:
### Generate Mock data
- Generate a field $s$ and $n$ with given covariances.
- Calculate $d$.
%% Cell type:code id: tags:
``` python
s_space = ift.RGSpace(N_pixels)
h_space = s_space.get_default_codomain()
HT = ift.HarmonicTransformOperator(h_space, target=s_space)
# Operators
Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)
R = HT  # * ift.create_harmonic_smoothing_operator((h_space,), 0, 0.02)
# Fields and data
sh = Sh.draw_sample()
noiseless_data = R(sh)
noise_amplitude = np.sqrt(0.2)
N = ift.ScalingOperator(noise_amplitude**2, s_space)
n = ift.Field.from_random(domain=s_space, random_type='normal',
                          std=noise_amplitude, mean=0)
d = noiseless_data + n
j = R.adjoint_times(N.inverse_times(d))
curv = Curvature(R=R, N=N, Sh=Sh)
D = curv.inverse
```
%% Cell type:markdown id: tags:
### Run Wiener Filter
%% Cell type:code id: tags:
``` python
m = D(j)
```
%% Cell type:markdown id: tags:
### Signal Reconstruction
%% Cell type:code id: tags:
``` python
# Get signal data and reconstruction data
s_data = HT(sh).to_global_data()
m_data = HT(m).to_global_data()
d_data = d.to_global_data()
plt.figure(figsize=(15,10))
plt.plot(s_data, 'r', label="Signal", linewidth=3)
plt.plot(d_data, 'k.', label="Data")
plt.plot(m_data, 'k', label="Reconstruction",linewidth=3)
plt.title("Reconstruction")
plt.legend()
plt.show()
```
%% Cell type:code id: tags:
``` python
plt.figure(figsize=(15,10))
plt.plot(s_data - s_data, 'r', label="Signal", linewidth=3)
plt.plot(d_data - s_data, 'k.', label="Data")
plt.plot(m_data - s_data, 'k', label="Reconstruction",linewidth=3)
plt.axhspan(-noise_amplitude,noise_amplitude, facecolor='0.9', alpha=.5)
plt.title("Residuals")
plt.legend()
plt.show()
```
%% Cell type:markdown id: tags:
### Power Spectrum
%% Cell type:code id: tags:
``` python
s_power_data = ift.power_analyze(sh).to_global_data()
m_power_data = ift.power_analyze(m).to_global_data()
plt.figure(figsize=(15,10))
plt.loglog()
plt.xlim(1, int(N_pixels/2))
ymin = min(m_power_data)
plt.ylim(ymin, 1)
xs = np.arange(1,int(N_pixels/2),.1)
plt.plot(xs, pow_spec(xs), label="True Power Spectrum", color='k',alpha=0.5)
plt.plot(s_power_data, 'r', label="Signal")
plt.plot(m_power_data, 'k', label="Reconstruction")
plt.axhline(noise_amplitude**2 / N_pixels, color="k", linestyle='--', label="Noise level", alpha=.5)
plt.axhspan(noise_amplitude**2 / N_pixels, ymin, facecolor='0.9', alpha=.5)
plt.title("Power Spectrum")
plt.legend()
plt.show()
```
%% Cell type:markdown id: tags:
## Wiener Filter on Incomplete Data
%% Cell type:code id: tags:
``` python
# Operators
Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)
N = ift.ScalingOperator(noise_amplitude**2,s_space)
# R is defined below
# Fields
sh = Sh.draw_sample()
s = HT(sh)
n = ift.Field.from_random(domain=s_space, random_type='normal',
                          std=noise_amplitude, mean=0)
```
%% Cell type:markdown id: tags:
### Partially Lose Data
%% Cell type:code id: tags:
``` python
l = int(N_pixels * 0.2)
h = int(N_pixels * 0.2 * 2)
mask = np.full(s_space.shape, 1.)
mask[l:h] = 0
mask = ift.Field.from_global_data(s_space, mask)
R = ift.DiagonalOperator(mask).chain(HT)
n = n.to_global_data().copy()
n[l:h] = 0
n = ift.Field.from_global_data(s_space, n)
d = R(sh) + n
```
%% Cell type:code id: tags:
``` python
curv = Curvature(R=R, N=N, Sh=Sh)
D = curv.inverse
j = R.adjoint_times(N.inverse_times(d))
m = D(j)
```
%% Cell type:markdown id: tags:
### Compute Uncertainty
%% Cell type:code id: tags:
``` python
m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 200)
```
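`probe_with_posterior_samples` estimates the posterior mean and pointwise variance from drawn samples; schematically it reduces a stack of samples like this (a hypothetical numpy stand-in, not the NIFTy call):

``` python
import numpy as np

# Draw hypothetical 'posterior samples' and reduce them to mean and variance,
# mirroring what a sample-based uncertainty probe computes per pixel.
rng = np.random.default_rng(2)
samples = rng.normal(loc=0., scale=1., size=(200, 512))  # 200 samples, 512 pixels
m_mean = samples.mean(axis=0)
m_var = samples.var(axis=0)
uncertainty = np.sqrt(m_var)   # one-sigma credible band per pixel
```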
%% Cell type:markdown id: tags:
### Get data
%% Cell type:code id: tags:
``` python
# Get signal data and reconstruction data
s_data = s.to_global_data()
m_data = HT(m).to_global_data()
m_var_data = m_var.to_global_data()
uncertainty = np.sqrt(m_var_data)
d_data = d.to_global_data().copy()
# Set lost data to NaN for proper plotting
d_data[d_data == 0] = np.nan
```
%% Cell type:code id: tags:
``` python
fig = plt.figure(figsize=(15,10))
plt.axvspan(l, h, facecolor='0.8',alpha=0.5)
plt.fill_between(range(N_pixels), m_data - uncertainty, m_data + uncertainty, facecolor='0.5', alpha=0.5)
plt.plot(s_data, 'r', label="Signal", alpha=1, linewidth=3)
plt.plot(d_data, 'k.', label="Data")
plt.plot(m_data, 'k', label="Reconstruction", linewidth=3)
plt.title("Reconstruction of incomplete data")
plt.legend()
```
%% Cell type:markdown id: tags:
# 2d Example
%% Cell type:code id: tags:
``` python
N_pixels = 256 # Number of pixels
sigma2 = 2. # Noise variance
def pow_spec(k):
P0, k0, gamma = [.2, 2, 4]
return P0 * (1. + (k/k0)**2)**(-gamma/2)
s_space = ift.RGSpace([N_pixels, N_pixels])
```
%% Cell type:code id: tags:
``` python
h_space = s_space.get_default_codomain()
HT = ift.HarmonicTransformOperator(h_space,s_space)
# Operators
Sh = ift.create_power_operator(h_space, power_spectrum=pow_spec)
N = ift.ScalingOperator(sigma2,s_space)
# Fields and data
sh = Sh.draw_sample()
n = ift.Field.from_random(domain=s_space, random_type='normal',
                          std=np.sqrt(sigma2), mean=0)
# Lose some data
l = int(N_pixels * 0.33)
h = int(N_pixels * 0.33 * 2)
mask = np.full(s_space.shape, 1.)
mask[l:h,l:h] = 0.
mask = ift.Field.from_global_data(s_space, mask)
R = ift.DiagonalOperator(mask).chain(HT)
n = n.to_global_data().copy()
n[l:h, l:h] = 0
n = ift.Field.from_global_data(s_space, n)
curv = Curvature(R=R, N=N, Sh=Sh)
D = curv.inverse
d = R(sh) + n
j = R.adjoint_times(N.inverse_times(d))
# Run Wiener filter
m = D(j)
# Uncertainty
m_mean, m_var = ift.probe_with_posterior_samples(curv, HT, 20)
# Get data
s_data = HT(sh).to_global_data()
m_data = HT(m).to_global_data()
m_var_data = m_var.to_global_data()
d_data = d.to_global_data()
uncertainty = np.sqrt(np.abs(m_var_data))
```
%% Cell type:code id: tags:
``` python
cm = ['magma', 'inferno', 'plasma', 'viridis'][1]
mi = np.min(s_data)
ma = np.max(s_data)
fig, axes = plt.subplots(1, 2, figsize=(15, 7))
data = [s_data, d_data]
caption = ["Signal", "Data"]
for ax in axes.flat:
im = ax.imshow(data.pop(0), interpolation='nearest', cmap=cm, vmin=mi,
vmax=ma)
ax.set_title(caption.pop(0))
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)
```
%% Cell type:code id: tags:
``` python
mi = np.min(s_data)
ma = np.max(s_data)
fig, axes = plt.subplots(3, 2, figsize=(15, 22.5))
sample = HT(curv.draw_sample(from_inverse=True)+m).to_global_data()
post_mean = (m_mean + HT(m)).to_global_data()
data = [s_data, m_data, post_mean, sample, s_data - m_data, uncertainty]
caption = ["Signal", "Reconstruction", "Posterior mean", "Sample", "Residuals", "Uncertainty Map"]
for ax in axes.flat:
im = ax.imshow(data.pop(0), interpolation='nearest', cmap=cm, vmin=mi, vmax=ma)
ax.set_title(caption.pop(0))
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)
```
%% Cell type:markdown id: tags:
### Is the uncertainty map reliable?
%% Cell type:code id: tags:
``` python
precise = (np.abs(s_data-m_data) < uncertainty)
print("Error within uncertainty map bounds: " + str(np.sum(precise) * 100 / N_pixels**2) + "%")
plt.figure(figsize=(15,10))
plt.imshow(precise.astype(float), cmap="brg")
plt.colorbar()
```
%% Cell type:markdown id: tags:
# Start Coding
## NIFTy Repository + Installation guide
https://gitlab.mpcdf.mpg.de/ift/NIFTy
NIFTy v5 is **more or less stable**!
@@ -78,7 +78,7 @@ if __name__ == '__main__':
     GR = ift.GeometryRemover(position_space)
     mask = ift.Field.from_global_data(position_space, mask)
     Mask = ift.DiagonalOperator(mask)
-    R = GR * Mask * HT
+    R = GR.chain(Mask).chain(HT)
     data_space = GR.target
@@ -93,7 +93,7 @@ if __name__ == '__main__':
     # Build propagator D and information source j
     j = R.adjoint_times(N.inverse_times(data))
-    D_inv = R.adjoint * N.inverse * R + S.inverse
+    D_inv = R.adjoint.chain(N.inverse).chain(R) + S.inverse
     # Make it invertible
     IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=1e-3)
     D = ift.InversionEnabler(D_inv, IC, approximation=S.inverse).inverse
@@ -112,7 +112,8 @@ if __name__ == '__main__':
                  title="getting_started_1")
     else:
         ift.plot(HT(MOCK_SIGNAL), title='Mock Signal')
-        ift.plot(mask_to_nan(mask, (GR*Mask).adjoint(data)), title='Data')
+        ift.plot(mask_to_nan(mask, (GR.chain(Mask)).adjoint(data)),
+                 title='Data')
         ift.plot(HT(m), title='Reconstruction')
         ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL)))
         ift.plot_finish(nx=2, ny=2, xsize=10, ysize=10,
@@ -75,7 +75,7 @@ if __name__ == '__main__':
     M = ift.DiagonalOperator(exposure)
     GR = ift.GeometryRemover(position_space)
     # Set up instrumental response
-    R = GR * M
+    R = GR.chain(M)
     # Generate mock data
     d_space = R.target[0]
@@ -23,7 +23,7 @@ class EnergyAdapter(Energy):
     @property
     def value(self):
         if self._val is None:
-            self._val = self._op(self._position)
+            self._val = self._op(self._position)
         return self._val

     @property
@@ -130,7 +130,7 @@ class AmplitudeModel(Operator):
         cepstrum = create_cepstrum_amplitude_field(dof_space, kern)
         ceps = makeOp(sqrt(cepstrum))
-        self._smooth_op = sym * qht * ceps
+        self._smooth_op = sym.chain(qht).chain(ceps)
         self._keys = tuple(keys)

     @property
@@ -47,8 +47,8 @@ class Linearization(object):
     def __neg__(self):
         return Linearization(
-            -self._val, self._jac*(-1),
-            None if self._metric is None else self._metric*(-1))
+            -self._val, self._jac.chain(-1),
+            None if self._metric is None else self._metric.chain(-1))

     def __add__(self, other):
         if isinstance(other, Linearization):
@@ -79,47 +79,49 @@ class Linearization(object):
             d2 = makeOp(other._val)
             return Linearization(
                 self._val*other._val,
-                RelaxedSumOperator((d2*self._jac, d1*other._jac)))
+                RelaxedSumOperator((d2.chain(self._jac),
+                                    d1.chain(other._jac))))
         if isinstance(other, (int, float, complex)):
             # if other == 0:
             #     return ...
-            met = None if self._metric is None else self._metric*other
-            return Linearization(self._val*other, self._jac*other, met)
+            met = None if self._metric is None else self._metric.chain(other)
+            return Linearization(self._val*other, self._jac.chain(other), met)
         if isinstance(other, (Field, MultiField)):
             d2 = makeOp(other)
-            return Linearization(self._val*other, d2*self._jac)
+            return Linearization(self._val*other, d2.chain(self._jac))
         raise TypeError

     def __rmul__(self, other):
         from .sugar import makeOp
         if isinstance(other, (int, float, complex)):
-            return Linearization(self._val*other, self._jac*other)
+            return Linearization(self._val*other, self._jac.chain(other))
         if isinstance(other, (Field, MultiField)):
             d1 = makeOp(other)
-            return Linearization(self._val*other, d1*self._jac)
+            return Linearization(self._val*other, d1.chain(self._jac))

     def sum(self):
         from .sugar import full
         from .operators.vdot_operator import VdotOperator
-        return Linearization(full((), self._val.sum()),
-                             VdotOperator(full(self._jac.target, 1))*self._jac)
+        return Linearization(
+            full((), self._val.sum()),
+            VdotOperator(full(self._jac.target, 1)).chain(self._jac))

     def exp(self):
         tmp = self._val.exp()
-        return Linearization(tmp, makeOp(tmp)*self._jac)
+        return Linearization(tmp, makeOp(tmp).chain(self._jac))

     def log(self):
         tmp = self._val.log()
-        return Linearization(tmp, makeOp(1./self._val)*self._jac)
+        return Linearization(tmp, makeOp(1./self._val).chain(self._jac))

     def tanh(self):
         tmp = self._val.tanh()
-        return Linearization(tmp, makeOp(1.-tmp**2)*self._jac)
+        return Linearization(tmp, makeOp(1.-tmp**2).chain(self._jac))

     def positive_tanh(self):
         tmp = self._val.tanh()
         tmp2 = 0.5*(1.+tmp)
-        return Linearization(tmp2, makeOp(0.5*(1.-tmp**2))*self._jac)
+        return Linearization(tmp2, makeOp(0.5*(1.-tmp**2)).chain(self._jac))

     def add_metric(self, metric):
         return Linearization(self._val, self._jac, metric)
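The pattern in this hunk is the chain rule of forward-mode differentiation: for $y=\exp(f(x))$ the Jacobian is $\mathrm{diag}(\exp(f))$ composed with $f$'s Jacobian, hence `makeOp(tmp).chain(self._jac)`. A scalar sketch of the same idea (illustrative, not NIFTy's `Linearization`):

``` python
import math

class Lin:
    """A value together with its derivative (scalar forward-mode sketch)."""
    def __init__(self, val, jac):
        self.val, self.jac = val, jac

    def exp(self):
        tmp = math.exp(self.val)
        return Lin(tmp, tmp*self.jac)            # analogue of makeOp(tmp).chain(jac)

    def tanh(self):
        tmp = math.tanh(self.val)
        return Lin(tmp, (1. - tmp**2)*self.jac)  # analogue of makeOp(1-tmp**2).chain(jac)

x = Lin(1., 1.)   # identity seed: d/dx x = 1
y = x.exp()       # y.val = e, y.jac = e
```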
@@ -68,7 +68,7 @@ class BlockDiagonalOperator(EndomorphicOperator):
     def _combine_chain(self, op):
         if self._domain is not op._domain:
             raise ValueError("domain mismatch")
-        res = tuple(v1*v2 for v1, v2 in zip(self._ops, op._ops))
+        res = tuple(v1.chain(v2) for v1, v2 in zip(self._ops, op._ops))
         return BlockDiagonalOperator(self._domain, res)

     def _combine_sum(self, op, selfneg, opneg):
@@ -67,4 +67,4 @@ def HarmonicSmoothingOperator(domain, sigma, space=None):
     ddom = list(domain)
     ddom[space] = codomain
     diag = DiagonalOperator(kernel, ddom, space)
-    return Hartley.inverse*diag*Hartley
+    return Hartley.inverse.chain(diag).chain(Hartley)
@@ -117,18 +117,46 @@ class LinearOperator(Operator):
     def __mul__(self, other):
         from .chain_operator import ChainOperator
-        if np.isscalar(other) and other == 1.:
+        if not np.isscalar(other):
+            return Operator.__mul__(self, other)
+        if other == 1.:
             return self
-        other = self._toOperator(other, self.domain)
+        from .scaling_operator import ScalingOperator
+        other = ScalingOperator(other, self.domain)
         return ChainOperator.make([self, other])

     def __rmul__(self, other):
         from .chain_operator import ChainOperator
-        if np.isscalar(other) and other == 1.:
+        if not np.isscalar(other):
+            return Operator.__rmul__(self, other)
+        if other == 1.:
             return self
-        other = self._toOperator(other, self.target)
+        from .scaling_operator import ScalingOperator
+        other = ScalingOperator(other, self.target)
         return ChainOperator.make([other, self])

+    def __matmul__(self, other):
+        if np.isscalar(other) and other == 1.:
+            return self
+        other2 = self._toOperator(other, self.domain)
+        if other2 == NotImplemented:
+            return Operator.__matmul__(self, other)
+        from .chain_operator import ChainOperator
+        return ChainOperator.make([self, other2])
+
+    def chain(self, other):
+        return self.__matmul__(other)
+
+    def __rmatmul__(self, other):
+        if np.isscalar(other) and other == 1.:
+            return self
+        other2 = self._toOperator(other, self.target)
+        if other2 == NotImplemented:
+            return Operator.__rmatmul__(self, other)
+        from .chain_operator import ChainOperator
+        return ChainOperator.make([other2, self])
+
     def __add__(self, other):
         from .sum_operator import SumOperator
         if np.isscalar(other) and other == 0.:
@@ -190,7 +218,7 @@ class LinearOperator(Operator):
         """Same as :meth:`times`"""
         from ..linearization import Linearization
         if isinstance(x, Linearization):
-            return Linearization(self(x._val), self*x._jac)
+            return Linearization(self(x._val), self.chain(x._jac))
         return self.apply(x, self.TIMES)

     def times(self, x):
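The net effect of the new methods: `A @ B` (and the Python 2 friendly `A.chain(B)`) builds the composition "apply B, then A", with scalars promoted to scaling operators. A toy model of that contract (illustrative, not NIFTy's `LinearOperator`):

``` python
class Op:
    """Tiny stand-in for a composable operator on scalars."""
    def __init__(self, f):
        self._f = f

    def __call__(self, x):
        return self._f(x)

    def __matmul__(self, other):             # A @ B  ==  x -> A(B(x))
        if isinstance(other, (int, float)):  # scalars act like scaling operators
            other = Op(lambda x, c=other: c*x)
        return Op(lambda x: self(other(x)))

    def chain(self, other):                  # spelling that also works in Python 2
        return self.__matmul__(other)

double = Op(lambda x: 2*x)
triple = Op(lambda x: 3*x)
assert (double @ triple)(3) == 18            # double(triple(3))
assert double.chain(triple)(3) == 18
assert double.chain(3)(5) == 30              # scalar chained like a scaling operator
```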
@@ -56,9 +56,9 @@ class SandwichOperator(EndomorphicOperator):
             raise TypeError("cheese must be a linear operator")
         if cheese is None:
             cheese = ScalingOperator(1., bun.target)
-            op = bun.adjoint*bun
+            op = bun.adjoint.chain(bun)
         else:
-            op = bun.adjoint*cheese*bun
+            op = bun.adjoint.chain(cheese).chain(bun)

         # if our sandwich is diagonal, we can return immediately
         if isinstance(op, (ScalingOperator, DiagonalOperator)):
@@ -54,4 +54,4 @@ def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None):
     if strength == 0.:
         return ScalingOperator(0., domain)
     laplace = LaplaceOperator(domain, logarithmic=logarithmic, space=space)
-    return (strength**2)*laplace.adjoint*laplace
+    return (strength**2)*laplace.adjoint.chain(laplace)
@@ -66,7 +66,7 @@ class Test_Functionality(unittest.TestCase):
         op1 = ift.create_power_operator((space1, space2), _spec1, 0)
         op2 = ift.create_power_operator((space1, space2), _spec2, 1)
-        opfull = op2*op1
+        opfull = op2.chain(op1)

         samples = 500
         sc1 = ift.StatCalculator()
@@ -94,7 +94,7 @@ class Test_Functionality(unittest.TestCase):
         S_1 = ift.create_power_operator((space1, space2), _spec1, 0)
         S_2 = ift.create_power_operator((space1, space2), _spec2, 1)
-        S_full = S_2*S_1
+        S_full = S_2.chain(S_1)

         samples = 500
         sc1 = ift.StatCalculator()
@@ -36,33 +36,6 @@ class Model_Tests(unittest.TestCase):
             return ift.Linearization.make_var(s)
         raise ValueError('unknown type passed')

-    def make_model(self, type, **kwargs):
-        if type == 'Constant':
-            np.random.seed(kwargs['seed'])
-            S = ift.ScalingOperator(1., kwargs['space'])
-            s = S.draw_sample()
-            return ift.Constant(
-                ift.MultiField.from_dict({kwargs['space_key']: s}),
-                ift.MultiField.from_dict({kwargs['space_key']: s}))
-        elif type == 'Variable':
-            np.random.seed(kwargs['seed'])
-            S = ift.ScalingOperator(1., kwargs['space'])
-            s = S.draw_sample()
-            return ift.Variable(
-                ift.MultiField.from_dict({kwargs['space_key']: s}))
-        elif type == 'LinearModel':
-            return ift.LinearModel(
-                inp=kwargs['model'], lin_op=kwargs['lin_op'])
-        else:
-            raise ValueError('unknown type passed')
-
-    def make_linear_operator(self, type, **kwargs):
-        if type == 'ScalingOperator':
-            lin_op = ift.ScalingOperator(1., kwargs['space'])
-        else:
-            raise ValueError('unknown type passed')
-        return lin_op
-
     @expand(product(
         [ift.GLSpace(15),
          ift.RGSpace(64, distances=.789),
@@ -71,7 +44,7 @@ class Model_Tests(unittest.TestCase):
         ))
     def testBasics(self, space, seed):
         var = self.make_linearization("Variable", space, seed)
-        model = lambda inp: inp
+        model = ift.ScalingOperator(6., var.target)
         ift.extra.check_value_gradient_consistency(model, var.val)

     @expand(product(
@@ -89,17 +62,17 @@ class Model_Tests(unittest.TestCase):
         lin2 = self.make_linearization(type2, dom2, seed)
         dom = ift.MultiDomain.union((dom1, dom2))
-        model = lambda inp: inp["s1"]*inp["s2"]
+        model = ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2")
         pos = ift.from_random("normal", dom)
         ift.extra.check_value_gradient_consistency(model, pos)
-        model = lambda inp: inp["s1"]+inp["s2"]
+        model = ift.FieldAdapter(dom, "s1")+ift.FieldAdapter(dom, "s2")
         pos = ift.from_random("normal", dom)
         ift.extra.check_value_gradient_consistency(model, pos)
-        model = lambda inp: inp["s1"]*3.
-        pos = ift.from_random("normal", dom1)
+        model = ift.FieldAdapter(dom, "s1")*3.
+        pos = ift.from_random("normal", dom)
         ift.extra.check_value_gradient_consistency(model, pos)
-        model = lambda inp: ift.ScalingOperator(2.456, space)(
-            inp["s1"]*inp["s2"])
+        model = ift.ScalingOperator(2.456, space).chain(
+            ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2"))
         pos = ift.from_random("normal", dom)
         ift.extra.check_value_gradient_consistency(model, pos)
-        model = lambda inp: ift.ScalingOperator(2.456, space)(
@@ -40,7 +40,7 @@ class Test_Functionality(unittest.TestCase):
     def test_blockdiagonal(self):
         op = ift.BlockDiagonalOperator(
             dom, (ift.ScalingOperator(20., dom["d1"]),))
-        op2 = op*op
+        op2 = op.chain(op)
         ift.extra.consistency_check(op2)
         assert_equal(type(op2), ift.BlockDiagonalOperator)
         f1 = op2(ift.full(dom, 1))
@@ -53,7 +53,7 @@ class Consistency_Tests(unittest.TestCase):
                                             dtype=dtype))
         op = ift.SandwichOperator.make(a, b)
         ift.extra.consistency_check(op, dtype, dtype)
-        op = a*b
+        op = a.chain(b)
         ift.extra.consistency_check(op, dtype, dtype)
         op = a+b
         ift.extra.consistency_check(op, dtype, dtype)
@@ -37,7 +37,7 @@ class ComposedOperator_Tests(unittest.TestCase):
         op1 = ift.DiagonalOperator(diag1, cspace, spaces=(0,))
         op2 = ift.DiagonalOperator(diag2, cspace, spaces=(1,))
-        op = op2*op1
+        op = op2.chain(op1)
         rand1 = ift.Field.from_random('normal', domain=(space1, space2))
         rand2 = ift.Field.from_random('normal', domain=(space1, space2))
@@ -54,7 +54,7 @@ class ComposedOperator_Tests(unittest.TestCase):
         op1 = ift.DiagonalOperator(diag1, cspace, spaces=(0,))
         op2 = ift.DiagonalOperator(diag2, cspace, spaces=(1,))
-        op = op2*op1
+        op = op2.chain(op1)
         rand1 = ift.Field.from_random('normal', domain=(space1, space2))
         tt1 = op.inverse_times(op.times(rand1))
@@ -75,7 +75,8 @@ class ComposedOperator_Tests(unittest.TestCase):
     def test_chain(self, space):
         op1 = ift.makeOp(ift.Field.full(space, 2.))
         op2 = 3.
-        full_op = op1 * op2 * (op2 * op1) * op1 * op1 * op2
+        full_op = (op1.chain(op2).chain(op2).chain(op1).
+                   chain(op1).chain(op1).chain(op2))
         x = ift.Field.full(space, 1.)
         res = full_op(x)
         assert_equal(isinstance(full_op, ift.DiagonalOperator), True)
@@ -85,7 +86,7 @@ class ComposedOperator_Tests(unittest.TestCase):
     def test_mix(self, space):
         op1 = ift.makeOp(ift.Field.full(space, 2.))
         op2 = 3.
-        full_op = op1 * (op2 + op2) * op1 * op1 - op1 * op2
+        full_op = op1.chain(op2 + op2).chain(op1).chain(op1) - op1.chain(op2)
         x = ift.Field.full(space, 1.)
         res = full_op(x)
         assert_equal(isinstance(full_op, ift.DiagonalOperator), True)