Commit c3ed466f authored by Martin Reinecke's avatar Martin Reinecke

no more chains

parent 369c6e7c
...@@ -429,7 +429,7 @@ ...@@ -429,7 +429,7 @@
"mask[l:h] = 0\n", "mask[l:h] = 0\n",
"mask = ift.Field.from_global_data(s_space, mask)\n", "mask = ift.Field.from_global_data(s_space, mask)\n",
"\n", "\n",
"R = ift.DiagonalOperator(mask).chain(HT)\n", "R = ift.DiagonalOperator(mask)(HT)\n",
"n = n.to_global_data_rw()\n", "n = n.to_global_data_rw()\n",
"n[l:h] = 0\n", "n[l:h] = 0\n",
"n = ift.Field.from_global_data(s_space, n)\n", "n = ift.Field.from_global_data(s_space, n)\n",
...@@ -585,7 +585,7 @@ ...@@ -585,7 +585,7 @@
"mask[l:h,l:h] = 0.\n", "mask[l:h,l:h] = 0.\n",
"mask = ift.Field.from_global_data(s_space, mask)\n", "mask = ift.Field.from_global_data(s_space, mask)\n",
"\n", "\n",
"R = ift.DiagonalOperator(mask).chain(HT)\n", "R = ift.DiagonalOperator(mask)(HT)\n",
"n = n.to_global_data_rw()\n", "n = n.to_global_data_rw()\n",
"n[l:h, l:h] = 0\n", "n[l:h, l:h] = 0\n",
"n = ift.Field.from_global_data(s_space, n)\n", "n = ift.Field.from_global_data(s_space, n)\n",
......
...@@ -53,7 +53,7 @@ if __name__ == '__main__': ...@@ -53,7 +53,7 @@ if __name__ == '__main__':
A = pd(a) A = pd(a)
# Set up a sky model # Set up a sky model
sky = HT.chain(ift.makeOp(A)).positive_tanh() sky = HT(ift.makeOp(A)).positive_tanh()
GR = ift.GeometryRemover(position_space) GR = ift.GeometryRemover(position_space)
# Set up instrumental response # Set up instrumental response
...@@ -61,7 +61,7 @@ if __name__ == '__main__': ...@@ -61,7 +61,7 @@ if __name__ == '__main__':
# Generate mock data # Generate mock data
d_space = R.target[0] d_space = R.target[0]
p = R.chain(sky) p = R(sky)
mock_position = ift.from_random('normal', harmonic_space) mock_position = ift.from_random('normal', harmonic_space)
pp = p(mock_position) pp = p(mock_position)
data = np.random.binomial(1, pp.to_global_data().astype(np.float64)) data = np.random.binomial(1, pp.to_global_data().astype(np.float64))
......
...@@ -78,7 +78,7 @@ if __name__ == '__main__': ...@@ -78,7 +78,7 @@ if __name__ == '__main__':
GR = ift.GeometryRemover(position_space) GR = ift.GeometryRemover(position_space)
mask = ift.Field.from_global_data(position_space, mask) mask = ift.Field.from_global_data(position_space, mask)
Mask = ift.DiagonalOperator(mask) Mask = ift.DiagonalOperator(mask)
R = GR.chain(Mask).chain(HT) R = GR(Mask(HT))
data_space = GR.target data_space = GR.target
...@@ -93,7 +93,7 @@ if __name__ == '__main__': ...@@ -93,7 +93,7 @@ if __name__ == '__main__':
# Build propagator D and information source j # Build propagator D and information source j
j = R.adjoint_times(N.inverse_times(data)) j = R.adjoint_times(N.inverse_times(data))
D_inv = R.adjoint.chain(N.inverse).chain(R) + S.inverse D_inv = R.adjoint(N.inverse(R)) + S.inverse
# Make it invertible # Make it invertible
IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=1e-3) IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=1e-3)
D = ift.InversionEnabler(D_inv, IC, approximation=S.inverse).inverse D = ift.InversionEnabler(D_inv, IC, approximation=S.inverse).inverse
...@@ -112,7 +112,7 @@ if __name__ == '__main__': ...@@ -112,7 +112,7 @@ if __name__ == '__main__':
title="getting_started_1") title="getting_started_1")
else: else:
ift.plot(HT(MOCK_SIGNAL), title='Mock Signal') ift.plot(HT(MOCK_SIGNAL), title='Mock Signal')
ift.plot(mask_to_nan(mask, (GR.chain(Mask)).adjoint(data)), ift.plot(mask_to_nan(mask, (GR(Mask)).adjoint(data)),
title='Data') title='Data')
ift.plot(HT(m), title='Reconstruction') ift.plot(HT(m), title='Reconstruction')
ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL)), title='Residuals') ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL)), title='Residuals')
......
...@@ -70,16 +70,16 @@ if __name__ == '__main__': ...@@ -70,16 +70,16 @@ if __name__ == '__main__':
A = pd(a) A = pd(a)
# Set up a sky model # Set up a sky model
sky = ift.exp(HT.chain(ift.makeOp(A))) sky = ift.exp(HT(ift.makeOp(A)))
M = ift.DiagonalOperator(exposure) M = ift.DiagonalOperator(exposure)
GR = ift.GeometryRemover(position_space) GR = ift.GeometryRemover(position_space)
# Set up instrumental response # Set up instrumental response
R = GR.chain(M) R = GR(M)
# Generate mock data # Generate mock data
d_space = R.target[0] d_space = R.target[0]
lamb = R.chain(sky) lamb = R(sky)
mock_position = ift.from_random('normal', domain) mock_position = ift.from_random('normal', domain)
data = lamb(mock_position) data = lamb(mock_position)
data = np.random.poisson(data.to_global_data().astype(np.float64)) data = np.random.poisson(data.to_global_data().astype(np.float64))
......
...@@ -44,8 +44,8 @@ if __name__ == '__main__': ...@@ -44,8 +44,8 @@ if __name__ == '__main__':
domain = ift.MultiDomain.union( domain = ift.MultiDomain.union(
(A.domain, ift.MultiDomain.make({'xi': harmonic_space}))) (A.domain, ift.MultiDomain.make({'xi': harmonic_space})))
correlated_field = ht.chain( correlated_field = ht(
power_distributor.chain(A)*ift.FieldAdapter(domain, "xi")) power_distributor(A)*ift.FieldAdapter(domain, "xi"))
# alternatively to the block above one can do: # alternatively to the block above one can do:
# correlated_field = ift.CorrelatedField(position_space, A) # correlated_field = ift.CorrelatedField(position_space, A)
...@@ -57,7 +57,7 @@ if __name__ == '__main__': ...@@ -57,7 +57,7 @@ if __name__ == '__main__':
R = ift.LOSResponse(position_space, starts=LOS_starts, R = ift.LOSResponse(position_space, starts=LOS_starts,
ends=LOS_ends) ends=LOS_ends)
# build signal response model and model likelihood # build signal response model and model likelihood
signal_response = R.chain(signal) signal_response = R(signal)
# specify noise # specify noise
data_space = R.target data_space = R.target
noise = .001 noise = .001
...@@ -69,7 +69,7 @@ if __name__ == '__main__': ...@@ -69,7 +69,7 @@ if __name__ == '__main__':
# set up model likelihood # set up model likelihood
likelihood = ift.GaussianEnergy( likelihood = ift.GaussianEnergy(
mean=data, covariance=N).chain(signal_response) mean=data, covariance=N)(signal_response)
# set up minimization and inversion schemes # set up minimization and inversion schemes
ic_cg = ift.GradientNormController(iteration_limit=10) ic_cg = ift.GradientNormController(iteration_limit=10)
......
...@@ -97,7 +97,7 @@ d = ift.from_global_data(d_space, y) ...@@ -97,7 +97,7 @@ d = ift.from_global_data(d_space, y)
N = ift.DiagonalOperator(ift.from_global_data(d_space, var)) N = ift.DiagonalOperator(ift.from_global_data(d_space, var))
IC = ift.GradientNormController(tol_abs_gradnorm=1e-8) IC = ift.GradientNormController(tol_abs_gradnorm=1e-8)
likelihood = ift.GaussianEnergy(d, N).chain(R) likelihood = ift.GaussianEnergy(d, N)(R)
H = ift.Hamiltonian(likelihood, IC) H = ift.Hamiltonian(likelihood, IC)
H = ift.EnergyAdapter(params, H) H = ift.EnergyAdapter(params, H)
H = H.make_invertible(IC) H = H.make_invertible(IC)
......
...@@ -40,7 +40,7 @@ class Hamiltonian(Operator): ...@@ -40,7 +40,7 @@ class Hamiltonian(Operator):
def target(self): def target(self):
return DomainTuple.scalar_domain() return DomainTuple.scalar_domain()
def __call__(self, x): def apply(self, x):
if self._ic_samp is None or not isinstance(x, Linearization): if self._ic_samp is None or not isinstance(x, Linearization):
return self._lh(x) + self._prior(x) return self._lh(x) + self._prior(x)
else: else:
......
...@@ -42,6 +42,6 @@ class SampledKullbachLeiblerDivergence(Operator): ...@@ -42,6 +42,6 @@ class SampledKullbachLeiblerDivergence(Operator):
def target(self): def target(self):
return DomainTuple.scalar_domain() return DomainTuple.scalar_domain()
def __call__(self, x): def apply(self, x):
return (my_sum(map(lambda v: self._h(x+v), self._res_samples)) * return (my_sum(map(lambda v: self._h(x+v), self._res_samples)) *
(1./len(self._res_samples))) (1./len(self._res_samples)))
...@@ -130,7 +130,7 @@ class AmplitudeModel(Operator): ...@@ -130,7 +130,7 @@ class AmplitudeModel(Operator):
cepstrum = create_cepstrum_amplitude_field(dof_space, kern) cepstrum = create_cepstrum_amplitude_field(dof_space, kern)
ceps = makeOp(sqrt(cepstrum)) ceps = makeOp(sqrt(cepstrum))
self._smooth_op = sym.chain(qht).chain(ceps) self._smooth_op = sym(qht(ceps))
self._keys = tuple(keys) self._keys = tuple(keys)
@property @property
...@@ -141,7 +141,7 @@ class AmplitudeModel(Operator): ...@@ -141,7 +141,7 @@ class AmplitudeModel(Operator):
def target(self): def target(self):
return self._target return self._target
def __call__(self, x): def apply(self, x):
smooth_spec = self._smooth_op(x[self._keys[0]]) smooth_spec = self._smooth_op(x[self._keys[0]])
phi = x[self._keys[1]] + self._norm_phi_mean phi = x[self._keys[1]] + self._norm_phi_mean
linear_spec = self._slope(phi) linear_spec = self._slope(phi)
......
...@@ -39,7 +39,7 @@ class BernoulliEnergy(Operator): ...@@ -39,7 +39,7 @@ class BernoulliEnergy(Operator):
def target(self): def target(self):
return DomainTuple.scalar_domain() return DomainTuple.scalar_domain()
def __call__(self, x): def apply(self, x):
x = self._p(x) x = self._p(x)
v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d) v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d)
if not isinstance(x, Linearization): if not isinstance(x, Linearization):
......
...@@ -58,7 +58,7 @@ class CorrelatedField(Operator): ...@@ -58,7 +58,7 @@ class CorrelatedField(Operator):
def target(self): def target(self):
return self._ht.target return self._ht.target
def __call__(self, x): def apply(self, x):
A = self._power_distributor(self._amplitude_model(x)) A = self._power_distributor(self._amplitude_model(x))
correlated_field_h = A * x["xi"] correlated_field_h = A * x["xi"]
correlated_field = self._ht(correlated_field_h) correlated_field = self._ht(correlated_field_h)
......
...@@ -55,7 +55,7 @@ class GaussianEnergy(Operator): ...@@ -55,7 +55,7 @@ class GaussianEnergy(Operator):
def target(self): def target(self):
return DomainTuple.scalar_domain() return DomainTuple.scalar_domain()
def __call__(self, x): def apply(self, x):
residual = x if self._mean is None else x-self._mean residual = x if self._mean is None else x-self._mean
icovres = residual if self._icov is None else self._icov(residual) icovres = residual if self._icov is None else self._icov(residual)
res = .5*residual.vdot(icovres) res = .5*residual.vdot(icovres)
......
...@@ -41,7 +41,7 @@ class PoissonianEnergy(Operator): ...@@ -41,7 +41,7 @@ class PoissonianEnergy(Operator):
def target(self): def target(self):
return DomainTuple.scalar_domain() return DomainTuple.scalar_domain()
def __call__(self, x): def apply(self, x):
x = self._op(x) x = self._op(x)
res = x.sum() - x.log().vdot(self._d) res = x.sum() - x.log().vdot(self._d)
if not isinstance(x, Linearization): if not isinstance(x, Linearization):
......
...@@ -46,8 +46,8 @@ class Linearization(object): ...@@ -46,8 +46,8 @@ class Linearization(object):
def __neg__(self): def __neg__(self):
return Linearization( return Linearization(
-self._val, self._jac.chain(-1), -self._val, self._jac*(-1),
None if self._metric is None else self._metric.chain(-1)) None if self._metric is None else self._metric*(-1))
def __add__(self, other): def __add__(self, other):
if isinstance(other, Linearization): if isinstance(other, Linearization):
...@@ -77,24 +77,24 @@ class Linearization(object): ...@@ -77,24 +77,24 @@ class Linearization(object):
d2 = makeOp(other._val) d2 = makeOp(other._val)
return Linearization( return Linearization(
self._val*other._val, self._val*other._val,
d2.chain(self._jac) + d1.chain(other._jac)) d2(self._jac) + d1(other._jac))
if isinstance(other, (int, float, complex)): if isinstance(other, (int, float, complex)):
# if other == 0: # if other == 0:
# return ... # return ...
met = None if self._metric is None else self._metric.chain(other) met = None if self._metric is None else self._metric(other)
return Linearization(self._val*other, self._jac.chain(other), met) return Linearization(self._val*other, self._jac(other), met)
if isinstance(other, (Field, MultiField)): if isinstance(other, (Field, MultiField)):
d2 = makeOp(other) d2 = makeOp(other)
return Linearization(self._val*other, d2.chain(self._jac)) return Linearization(self._val*other, d2(self._jac))
raise TypeError raise TypeError
def __rmul__(self, other): def __rmul__(self, other):
from .sugar import makeOp from .sugar import makeOp
if isinstance(other, (int, float, complex)): if isinstance(other, (int, float, complex)):
return Linearization(self._val*other, self._jac.chain(other)) return Linearization(self._val*other, self._jac(other))
if isinstance(other, (Field, MultiField)): if isinstance(other, (Field, MultiField)):
d1 = makeOp(other) d1 = makeOp(other)
return Linearization(self._val*other, d1.chain(self._jac)) return Linearization(self._val*other, d1(self._jac))
def vdot(self, other): def vdot(self, other):
from .domain_tuple import DomainTuple from .domain_tuple import DomainTuple
...@@ -102,11 +102,11 @@ class Linearization(object): ...@@ -102,11 +102,11 @@ class Linearization(object):
if isinstance(other, (Field, MultiField)): if isinstance(other, (Field, MultiField)):
return Linearization( return Linearization(
Field(DomainTuple.scalar_domain(),self._val.vdot(other)), Field(DomainTuple.scalar_domain(),self._val.vdot(other)),
VdotOperator(other).chain(self._jac)) VdotOperator(other)(self._jac))
return Linearization( return Linearization(
Field(DomainTuple.scalar_domain(),self._val.vdot(other._val)), Field(DomainTuple.scalar_domain(),self._val.vdot(other._val)),
VdotOperator(self._val).chain(other._jac) + VdotOperator(self._val)(other._jac) +
VdotOperator(other._val).chain(self._jac)) VdotOperator(other._val)(self._jac))
def sum(self): def sum(self):
from .domain_tuple import DomainTuple from .domain_tuple import DomainTuple
...@@ -114,24 +114,24 @@ class Linearization(object): ...@@ -114,24 +114,24 @@ class Linearization(object):
from .sugar import full from .sugar import full
return Linearization( return Linearization(
Field(DomainTuple.scalar_domain(), self._val.sum()), Field(DomainTuple.scalar_domain(), self._val.sum()),
SumReductionOperator(self._jac.target).chain(self._jac)) SumReductionOperator(self._jac.target)(self._jac))
def exp(self): def exp(self):
tmp = self._val.exp() tmp = self._val.exp()
return Linearization(tmp, makeOp(tmp).chain(self._jac)) return Linearization(tmp, makeOp(tmp)(self._jac))
def log(self): def log(self):
tmp = self._val.log() tmp = self._val.log()
return Linearization(tmp, makeOp(1./self._val).chain(self._jac)) return Linearization(tmp, makeOp(1./self._val)(self._jac))
def tanh(self): def tanh(self):
tmp = self._val.tanh() tmp = self._val.tanh()
return Linearization(tmp, makeOp(1.-tmp**2).chain(self._jac)) return Linearization(tmp, makeOp(1.-tmp**2)(self._jac))
def positive_tanh(self): def positive_tanh(self):
tmp = self._val.tanh() tmp = self._val.tanh()
tmp2 = 0.5*(1.+tmp) tmp2 = 0.5*(1.+tmp)
return Linearization(tmp2, makeOp(0.5*(1.-tmp**2)).chain(self._jac)) return Linearization(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac))
def add_metric(self, metric): def add_metric(self, metric):
return Linearization(self._val, self._jac, metric) return Linearization(self._val, self._jac, metric)
......
...@@ -68,7 +68,7 @@ class BlockDiagonalOperator(EndomorphicOperator): ...@@ -68,7 +68,7 @@ class BlockDiagonalOperator(EndomorphicOperator):
def _combine_chain(self, op): def _combine_chain(self, op):
if self._domain is not op._domain: if self._domain is not op._domain:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
res = tuple(v1.chain(v2) for v1, v2 in zip(self._ops, op._ops)) res = tuple(v1(v2) for v1, v2 in zip(self._ops, op._ops))
return BlockDiagonalOperator(self._domain, res) return BlockDiagonalOperator(self._domain, res)
def _combine_sum(self, op, selfneg, opneg): def _combine_sum(self, op, selfneg, opneg):
......
...@@ -67,4 +67,4 @@ def HarmonicSmoothingOperator(domain, sigma, space=None): ...@@ -67,4 +67,4 @@ def HarmonicSmoothingOperator(domain, sigma, space=None):
ddom = list(domain) ddom = list(domain)
ddom[space] = codomain ddom[space] = codomain
diag = DiagonalOperator(kernel, ddom, space) diag = DiagonalOperator(kernel, ddom, space)
return Hartley.inverse.chain(diag).chain(Hartley) return Hartley.inverse(diag(Hartley))
...@@ -142,9 +142,6 @@ class LinearOperator(Operator): ...@@ -142,9 +142,6 @@ class LinearOperator(Operator):
from .chain_operator import ChainOperator from .chain_operator import ChainOperator
return ChainOperator.make([self, other2]) return ChainOperator.make([self, other2])
def chain(self, other):
return self.__matmul__(other)
def __rmatmul__(self, other): def __rmatmul__(self, other):
if np.isscalar(other) and other == 1.: if np.isscalar(other) and other == 1.:
return self return self
...@@ -213,10 +210,14 @@ class LinearOperator(Operator): ...@@ -213,10 +210,14 @@ class LinearOperator(Operator):
def __call__(self, x): def __call__(self, x):
"""Same as :meth:`times`""" """Same as :meth:`times`"""
from ..field import Field
from ..multi.multi_field import MultiField
if isinstance(x, (Field, MultiField)):
return self.apply(x, self.TIMES)
from ..linearization import Linearization from ..linearization import Linearization
if isinstance(x, Linearization): if isinstance(x, Linearization):
return Linearization(self(x._val), self.chain(x._jac)) return Linearization(self(x._val), self(x._jac))
return self.apply(x, self.TIMES) return self.__matmul__(x)
def times(self, x): def times(self, x):
""" Applies the Operator to a given Field. """ Applies the Operator to a given Field.
......
...@@ -33,26 +33,13 @@ class Operator(NiftyMetaBase()): ...@@ -33,26 +33,13 @@ class Operator(NiftyMetaBase()):
return NotImplemented return NotImplemented
return _OpProd.make((self, x)) return _OpProd.make((self, x))
def chain(self, x): def apply(self, x):
res = self.__matmul__(x) raise NotImplementedError
if res == NotImplemented:
raise TypeError("operator expected")
return res
def __call__(self, x): def __call__(self, x):
"""Returns transformed x if isinstance(x, Operator):
return _OpChain.make((self, x))
Parameters return self.apply(x)
----------
x : Linearization
input
Returns
-------
Linearization
output
"""
raise NotImplementedError
for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]: for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]:
...@@ -78,7 +65,7 @@ class _FunctionApplier(Operator): ...@@ -78,7 +65,7 @@ class _FunctionApplier(Operator):
def target(self): def target(self):
return self._domain return self._domain
def __call__(self, x): def apply(self, x):
return getattr(x, self._funcname)() return getattr(x, self._funcname)()
...@@ -117,7 +104,7 @@ class _OpChain(_CombinedOperator): ...@@ -117,7 +104,7 @@ class _OpChain(_CombinedOperator):
def target(self): def target(self):
return self._ops[0].target return self._ops[0].target
def __call__(self, x): def apply(self, x):
for op in reversed(self._ops): for op in reversed(self._ops):
x = op(x) x = op(x)
return x return x
...@@ -135,7 +122,7 @@ class _OpProd(_CombinedOperator): ...@@ -135,7 +122,7 @@ class _OpProd(_CombinedOperator):
def target(self): def target(self):
return self._ops[0].target return self._ops[0].target
def __call__(self, x): def apply(self, x):
from ..utilities import my_product from ..utilities import my_product
return my_product(map(lambda op: op(x), self._ops)) return my_product(map(lambda op: op(x), self._ops))
...@@ -154,7 +141,7 @@ class _OpSum(_CombinedOperator): ...@@ -154,7 +141,7 @@ class _OpSum(_CombinedOperator):
def target(self): def target(self):
return self._target return self._target
def __call__(self, x): def apply(self, x):
raise NotImplementedError raise NotImplementedError
...@@ -193,7 +180,7 @@ class QuadraticFormOperator(Operator): ...@@ -193,7 +180,7 @@ class QuadraticFormOperator(Operator):
def target(self): def target(self):
return self._target return self._target
def __call__(self, x): def apply(self, x):
if isinstance(x, Linearization): if isinstance(x, Linearization):
jac = self._op(x) jac = self._op(x)
val = Field(self._target, 0.5