Commit a8f5fcd1 authored by Philipp Arras's avatar Philipp Arras

New apply logic

parent a778ee76
Pipeline #70576 failed with stages
in 13 minutes and 4 seconds
...@@ -15,7 +15,11 @@ from .multi_domain import MultiDomain ...@@ -15,7 +15,11 @@ from .multi_domain import MultiDomain
from .field import Field from .field import Field
from .multi_field import MultiField from .multi_field import MultiField
from .linearization import Linearization
from .operators.operator import Operator from .operators.operator import Operator
from .operators.linear_operator import LinearOperator
from .operators.adder import Adder from .operators.adder import Adder
from .operators.diagonal_operator import DiagonalOperator from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import DOFDistributor, PowerDistributor from .operators.distributors import DOFDistributor, PowerDistributor
...@@ -28,7 +32,6 @@ from .operators.harmonic_operators import ( ...@@ -28,7 +32,6 @@ from .operators.harmonic_operators import (
HarmonicSmoothingOperator) HarmonicSmoothingOperator)
from .operators.field_zero_padder import FieldZeroPadder from .operators.field_zero_padder import FieldZeroPadder
from .operators.inversion_enabler import InversionEnabler from .operators.inversion_enabler import InversionEnabler
from .operators.linear_operator import LinearOperator
from .operators.mask_operator import MaskOperator from .operators.mask_operator import MaskOperator
from .operators.regridding_operator import RegriddingOperator from .operators.regridding_operator import RegriddingOperator
from .operators.sampling_enabler import SamplingEnabler from .operators.sampling_enabler import SamplingEnabler
...@@ -87,8 +90,6 @@ from .utilities import memo, frozendict ...@@ -87,8 +90,6 @@ from .utilities import memo, frozendict
from .logger import logger from .logger import logger
from .linearization import Linearization
from .operator_spectrum import operator_spectrum from .operator_spectrum import operator_spectrum
# We deliberately don't set __all__ here, because we don't want people to do a # We deliberately don't set __all__ here, because we don't want people to do a
......
...@@ -114,7 +114,7 @@ def _actual_domain_check_nonlinear(op, loc): ...@@ -114,7 +114,7 @@ def _actual_domain_check_nonlinear(op, loc):
assert_(reslin.jac.target is reslin.target) assert_(reslin.jac.target is reslin.target)
_actual_domain_check_linear(reslin.jac, inp=loc) _actual_domain_check_linear(reslin.jac, inp=loc)
_actual_domain_check_linear(reslin.jac.adjoint, inp=reslin.jac(loc)) _actual_domain_check_linear(reslin.jac.adjoint, inp=reslin.jac(loc))
if wm: if reslin.metric is not None:
assert_(reslin.metric.domain is reslin.metric.target) assert_(reslin.metric.domain is reslin.metric.target)
assert_(reslin.metric.domain is op.domain) assert_(reslin.metric.domain is op.domain)
...@@ -153,7 +153,7 @@ def _performance_check(op, pos, raise_on_fail): ...@@ -153,7 +153,7 @@ def _performance_check(op, pos, raise_on_fail):
cond.append(cop.count != 3) cond.append(cop.count != 3)
lin.jac.adjoint(lin.val) lin.jac.adjoint(lin.val)
cond.append(cop.count != 4) cond.append(cop.count != 4)
if wm and myop.target is DomainTuple.scalar_domain(): if lin.metric is not None:
lin.metric(pos) lin.metric(pos)
cond.append(cop.count != 6) cond.append(cop.count != 6)
if any(cond): if any(cond):
......
...@@ -25,6 +25,7 @@ from ..domain_tuple import DomainTuple ...@@ -25,6 +25,7 @@ from ..domain_tuple import DomainTuple
from ..domains.power_space import PowerSpace from ..domains.power_space import PowerSpace
from ..domains.unstructured_domain import UnstructuredDomain from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field from ..field import Field
from ..linearization import Linearization
from ..logger import logger from ..logger import logger
from ..multi_field import MultiField from ..multi_field import MultiField
from ..operators.adder import Adder from ..operators.adder import Adder
...@@ -221,14 +222,15 @@ class _Normalization(Operator): ...@@ -221,14 +222,15 @@ class _Normalization(Operator):
self._mode_multiplicity = makeOp(makeField(self._domain, mode_multiplicity)) self._mode_multiplicity = makeOp(makeField(self._domain, mode_multiplicity))
self._specsum = _SpecialSum(self._domain, space) self._specsum = _SpecialSum(self._domain, space)
def apply(self, x): def apply(self, x, difforder):
self._check_input(x) self._check_input(x)
fa = FieldAdapter(self._domain, 'foo') if difforder >= self.WITH_JAC:
amp = fa.exp() x = Linearization.make_var(x, difforder == self.WITH_METRIC)
spec = (2*fa).exp() amp = x.exp()
spec = (2*x).exp()
# FIXME This normalizes also the zeromode which is supposed to be left # FIXME This normalizes also the zeromode which is supposed to be left
# untouched by this operator # untouched by this operator
return (self._specsum(self._mode_multiplicity(spec))**(-0.5)*amp)(fa.adjoint(x)) return self._specsum(self._mode_multiplicity(spec))**(-0.5)*amp
class _SpecialSum(EndomorphicOperator): class _SpecialSum(EndomorphicOperator):
......
...@@ -34,17 +34,8 @@ def _float_or_listoffloat(inp): ...@@ -34,17 +34,8 @@ def _float_or_listoffloat(inp):
return [float(x) for x in inp] if isinstance(inp, list) else float(inp) return [float(x) for x in inp] if isinstance(inp, list) else float(inp)
def _make_dynamic_operator(target, def _make_dynamic_operator(target, harmonic_padding, sm_s0, sm_x0, cone, keys, causal,
harmonic_padding, minimum_phase, sigc=None, quant=None, codomain=None):
sm_s0,
sm_x0,
cone,
keys,
causal,
minimum_phase,
sigc=None,
quant=None,
codomain=None):
if not isinstance(target, RGSpace): if not isinstance(target, RGSpace):
raise TypeError("RGSpace required") raise TypeError("RGSpace required")
if not target.harmonic: if not target.harmonic:
...@@ -128,7 +119,7 @@ def _make_dynamic_operator(target, ...@@ -128,7 +119,7 @@ def _make_dynamic_operator(target,
scaling = DiagonalOperator(Field(c.target, scaling)) scaling = DiagonalOperator(Field(c.target, scaling))
ops['lightspeed'] = scaling(lightspeed) ops['lightspeed'] = scaling(lightspeed)
c = LightConeOperator(c.target, m.target, quant)(c.exp()) c = LightConeOperator(c.target, m.target, quant) @ c.exp()
ops['light_cone'] = c ops['light_cone'] = c
m = c*m m = c*m
...@@ -139,13 +130,7 @@ def _make_dynamic_operator(target, ...@@ -139,13 +130,7 @@ def _make_dynamic_operator(target,
return m, ops return m, ops
def dynamic_operator(*, def dynamic_operator(*, target, harmonic_padding, sm_s0, sm_x0, key, causal=True,
target,
harmonic_padding,
sm_s0,
sm_x0,
key,
causal=True,
minimum_phase=False): minimum_phase=False):
"""Constructs an operator encoding the Green's function of a linear """Constructs an operator encoding the Green's function of a linear
homogeneous dynamic system. homogeneous dynamic system.
...@@ -206,17 +191,8 @@ def dynamic_operator(*, ...@@ -206,17 +191,8 @@ def dynamic_operator(*,
return _make_dynamic_operator(**dct) return _make_dynamic_operator(**dct)
def dynamic_lightcone_operator(*, def dynamic_lightcone_operator(*, target, harmonic_padding, sm_s0, sm_x0, key, lightcone_key,
target, sigc, quant, causal=True, minimum_phase=False):
harmonic_padding,
sm_s0,
sm_x0,
key,
lightcone_key,
sigc,
quant,
causal=True,
minimum_phase=False):
'''Extends the functionality of :func:`dynamic_operator` to a Green's '''Extends the functionality of :func:`dynamic_operator` to a Green's
function which is constrained to be within a light cone. function which is constrained to be within a light cone.
......
...@@ -131,12 +131,10 @@ class LightConeOperator(Operator): ...@@ -131,12 +131,10 @@ class LightConeOperator(Operator):
self._target = DomainTuple.make(target) self._target = DomainTuple.make(target)
self._sigx = sigx self._sigx = sigx
def apply(self, x): def apply(self, x, difforder):
islin = isinstance(x, Linearization) a, derivs = _cone_arrays(x.val, self.target, self._sigx, difforder >= self.WITH_JAC)
val = x.val.val if islin else x.val
a, derivs = _cone_arrays(val, self.target, self._sigx, islin)
res = Field(self.target, a) res = Field(self.target, a)
if not islin: if difforder == self.VALUE_ONLY:
return res return res
jac = _LightConeDerivative(x.jac.target, self.target, derivs)(x.jac) jac = _LightConeDerivative(self._domain, self._target, derivs)
return Linearization(res, jac, want_metric=x.want_metric) return Linearization(res, jac)
...@@ -38,19 +38,17 @@ class _InterpolationOperator(Operator): ...@@ -38,19 +38,17 @@ class _InterpolationOperator(Operator):
self._deriv = (self._table[1:]-self._table[:-1]) / self._d self._deriv = (self._table[1:]-self._table[:-1]) / self._d
self._inv_table_func = inverse_table_func self._inv_table_func = inverse_table_func
def apply(self, x): def apply(self, x, difforder):
self._check_input(x) self._check_input(x)
lin = isinstance(x, Linearization) val = (np.clip(x.val, self._xmin, self._xmax) - self._xmin) / self._d
val = x.val.val if lin else x.val
val = (np.clip(val, self._xmin, self._xmax) - self._xmin) / self._d
fi = np.floor(val).astype(int) fi = np.floor(val).astype(int)
w = val - fi w = val - fi
res = self._inv_table_func((1-w)*self._table[fi] + w*self._table[fi+1]) res = self._inv_table_func((1-w)*self._table[fi] + w*self._table[fi+1])
resfld = Field(self._domain, res) resfld = Field(self._domain, res)
if not lin: if difforder == self.VALUE_ONLY:
return resfld return resfld
jac = makeOp(Field(self._domain, self._deriv[fi]*res)) @ x.jac jac = makeOp(Field(self._domain, self._deriv[fi]*res))
return x.new(resfld, jac) return Linearization(resfld, jac)
def InverseGammaOperator(domain, alpha, q, delta=0.001): def InverseGammaOperator(domain, alpha, q, delta=0.001):
......
...@@ -63,6 +63,13 @@ class Linearization(object): ...@@ -63,6 +63,13 @@ class Linearization(object):
""" """
return Linearization(val, jac, metric, self._want_metric) return Linearization(val, jac, metric, self._want_metric)
def prepend_jac(self, jac):
metric = None
if self._metric is not None:
from .operators.sandwich_operator import SandwichOperator
metric = None if self._metric is None else SandwichOperator.make(jac, self._metric)
return self.new(self._val, self._jac @ jac, metric)
@property @property
def domain(self): def domain(self):
"""DomainTuple or MultiDomain : the Jacobian's domain""" """DomainTuple or MultiDomain : the Jacobian's domain"""
......
...@@ -18,9 +18,10 @@ ...@@ -18,9 +18,10 @@
import numpy as np import numpy as np
from ..field import Field from ..field import Field
from ..linearization import Linearization
from ..multi_field import MultiField from ..multi_field import MultiField
from .operator import Operator
from ..sugar import makeDomain from ..sugar import makeDomain
from .operator import Operator
class Adder(Operator): class Adder(Operator):
...@@ -42,8 +43,10 @@ class Adder(Operator): ...@@ -42,8 +43,10 @@ class Adder(Operator):
self._domain = self._target = dom self._domain = self._target = dom
self._neg = bool(neg) self._neg = bool(neg)
def apply(self, x): def apply(self, x, difforder):
self._check_input(x) self._check_input(x)
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
if self._neg: if self._neg:
return x - self._a return x - self._a
return x + self._a return x + self._a
...@@ -58,13 +58,13 @@ class Squared2NormOperator(EnergyOperator): ...@@ -58,13 +58,13 @@ class Squared2NormOperator(EnergyOperator):
def __init__(self, domain): def __init__(self, domain):
self._domain = domain self._domain = domain
def apply(self, x): def apply(self, x, difforder):
self._check_input(x) self._check_input(x)
if isinstance(x, Linearization): res = Field.scalar(x.vdot(x))
val = Field.scalar(x.val.vdot(x.val)) if difforder == self.VALUE_ONLY:
jac = VdotOperator(2*x.val)(x.jac) return res
return x.new(val, jac) jac = VdotOperator(2*x)
return Field.scalar(x.vdot(x)) return Linearization(res, jac, want_metric=difforder == self.WITH_METRIC)
class QuadraticFormOperator(EnergyOperator): class QuadraticFormOperator(EnergyOperator):
...@@ -87,14 +87,13 @@ class QuadraticFormOperator(EnergyOperator): ...@@ -87,14 +87,13 @@ class QuadraticFormOperator(EnergyOperator):
self._op = endo self._op = endo
self._domain = endo.domain self._domain = endo.domain
def apply(self, x): def apply(self, x, difforder):
self._check_input(x) self._check_input(x)
if isinstance(x, Linearization): t1 = self._op(x)
t1 = self._op(x.val) res = Field.scalar(0.5*x.vdot(t1))
jac = VdotOperator(t1)(x.jac) if difforder == self.VALUE_ONLY:
val = Field.scalar(0.5*x.val.vdot(t1)) return res
return x.new(val, jac) return Linearization(res, VdotOperator(t1))
return Field.scalar(0.5*x.vdot(self._op(x)))
class VariableCovarianceGaussianEnergy(EnergyOperator): class VariableCovarianceGaussianEnergy(EnergyOperator):
...@@ -128,19 +127,17 @@ class VariableCovarianceGaussianEnergy(EnergyOperator): ...@@ -128,19 +127,17 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
dom = DomainTuple.make(domain) dom = DomainTuple.make(domain)
self._domain = MultiDomain.make({self._r: dom, self._icov: dom}) self._domain = MultiDomain.make({self._r: dom, self._icov: dom})
def apply(self, x): def apply(self, x, difforder):
self._check_input(x) self._check_input(x)
lin = isinstance(x, Linearization) if difforder >= self.WITH_JAC:
r = FieldAdapter(self._domain[self._r], self._r) x = Linearization.make_var(x, difforder == self.WITH_METRIC)
icov = FieldAdapter(self._domain[self._icov], self._icov) res = 0.5*(x[self._r].vdot(x[self._r]*x[self._icov]).real - x[self._icov].log().sum())
res0 = r.vdot(r*icov).real if difforder == self.VALUE_ONLY:
res1 = icov.log().sum() return Field.scalar(res)
res = (res0-res1).scale(0.5)(x) if difforder == self.WITH_JAC:
if not lin or not x.want_metric:
return res return res
mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)} mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)}
metric = makeOp(MultiField.from_dict(mf)) return res.add_metric(makeOp(MultiField.from_dict(mf)))
return res.add_metric(SandwichOperator.make(x.jac, metric))
class GaussianEnergy(EnergyOperator): class GaussianEnergy(EnergyOperator):
...@@ -187,9 +184,10 @@ class GaussianEnergy(EnergyOperator): ...@@ -187,9 +184,10 @@ class GaussianEnergy(EnergyOperator):
self._mean = mean self._mean = mean
if inverse_covariance is None: if inverse_covariance is None:
self._op = Squared2NormOperator(self._domain).scale(0.5) self._op = Squared2NormOperator(self._domain).scale(0.5)
self._met = ScalingOperator(self._domain, 1)
else: else:
self._op = QuadraticFormOperator(inverse_covariance) self._op = QuadraticFormOperator(inverse_covariance)
self._icov = None if inverse_covariance is None else inverse_covariance self._met = inverse_covariance
def _checkEquivalence(self, newdom): def _checkEquivalence(self, newdom):
newdom = makeDomain(newdom) newdom = makeDomain(newdom)
...@@ -199,14 +197,15 @@ class GaussianEnergy(EnergyOperator): ...@@ -199,14 +197,15 @@ class GaussianEnergy(EnergyOperator):
if self._domain != newdom: if self._domain != newdom:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
def apply(self, x): def apply(self, x, difforder):
self._check_input(x) self._check_input(x)
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
residual = x if self._mean is None else x - self._mean residual = x if self._mean is None else x - self._mean
res = self._op(residual).real res = self._op(residual).real
if not isinstance(x, Linearization) or not x.want_metric: if difforder < self.WITH_METRIC:
return res return res
metric = SandwichOperator.make(x.jac, self._icov) return res.add_metric(self._met)
return res.add_metric(metric)
class PoissonianEnergy(EnergyOperator): class PoissonianEnergy(EnergyOperator):
...@@ -236,14 +235,16 @@ class PoissonianEnergy(EnergyOperator): ...@@ -236,14 +235,16 @@ class PoissonianEnergy(EnergyOperator):
self._d = d self._d = d
self._domain = DomainTuple.make(d.domain) self._domain = DomainTuple.make(d.domain)
def apply(self, x): def apply(self, x, difforder):
self._check_input(x) self._check_input(x)
fa = FieldAdapter(self._domain, 'foo') if difforder >= self.WITH_JAC:
res = (fa.sum() - fa.log().vdot(self._d))(fa.adjoint(x)) x = Linearization.make_var(x, difforder == self.WITH_METRIC)
if not isinstance(x, Linearization) or not x.want_metric: res = x.sum() - x.log().vdot(self._d)
if difforder == self.VALUE_ONLY:
return Field.scalar(res)
if difforder == self.WITH_JAC:
return res return res
metric = SandwichOperator.make(x.jac, makeOp(1./x.val)) return res.add_metric(makeOp(1./x.val))
return res.add_metric(metric)
class InverseGammaLikelihood(EnergyOperator): class InverseGammaLikelihood(EnergyOperator):
...@@ -278,15 +279,16 @@ class InverseGammaLikelihood(EnergyOperator): ...@@ -278,15 +279,16 @@ class InverseGammaLikelihood(EnergyOperator):
raise TypeError raise TypeError
self._alphap1 = alpha+1 self._alphap1 = alpha+1
def apply(self, x): def apply(self, x, difforder):
self._check_input(x) self._check_input(x)
fa = FieldAdapter(self._domain, 'foo') if difforder >= self.WITH_JAC:
x = fa.adjoint(x) x = Linearization.make_var(x, difforder == self.WITH_METRIC)
res = (fa.log().vdot(self._alphap1) + fa.one_over().vdot(self._beta))(x) res = x.log().vdot(self._alphap1) + x.one_over().vdot(self._beta)
if not isinstance(x, Linearization) or not x.want_metric: if difforder == self.VALUE_ONLY:
return Field.scalar(res)
if difforder == self.WITH_JAC:
return res return res
metric = SandwichOperator.make(x.jac, makeOp(self._alphap1/(x.val**2))) return res.add_metric(makeOp(self._alphap1/(x.val**2)))
return res.add_metric(metric)
class StudentTEnergy(EnergyOperator): class StudentTEnergy(EnergyOperator):
...@@ -310,16 +312,17 @@ class StudentTEnergy(EnergyOperator): ...@@ -310,16 +312,17 @@ class StudentTEnergy(EnergyOperator):
self._domain = DomainTuple.make(domain) self._domain = DomainTuple.make(domain)
self._theta = theta self._theta = theta
def apply(self, x): def apply(self, x, difforder):
self._check_input(x) self._check_input(x)
v = ((self._theta+1)/2)*(x**2/self._theta).log1p().sum() if difforder >= self.WITH_JAC:
if not isinstance(x, Linearization): x = Linearization.make_var(x, difforder == self.WITH_METRIC)
return Field.scalar(v) res = ((self._theta+1)/2)*(x**2/self._theta).log1p().sum()
if not x.want_metric: if difforder == self.VALUE_ONLY:
return v return Field.scalar(res)
if difforder == self.WITH_JAC:
return res
met = ScalingOperator(self.domain, (self._theta+1) / (self._theta+3)) met = ScalingOperator(self.domain, (self._theta+1) / (self._theta+3))
met = SandwichOperator.make(x.jac, met) return res.add_metric(met)
return v.add_metric(met)
class BernoulliEnergy(EnergyOperator): class BernoulliEnergy(EnergyOperator):
...@@ -347,17 +350,18 @@ class BernoulliEnergy(EnergyOperator): ...@@ -347,17 +350,18 @@ class BernoulliEnergy(EnergyOperator):
self._d = d self._d = d
self._domain = DomainTuple.make(d.domain) self._domain = DomainTuple.make(d.domain)
def apply(self, x): def apply(self, x, difforder):
self._check_input(x) self._check_input(x)
iden = FieldAdapter(self._domain, 'foo') if difforder >= self.WITH_JAC:
from .adder import Adder x = Linearization.make_var(x, difforder == self.WITH_METRIC)
v = -iden.log().vdot(self._d) + (Adder(1, domain=self._domain) @ iden.scale(-1)).log().vdot(self._d-1.) res = -x.log().vdot(self._d) + (1.-x).log().vdot(self._d-1.)
v = v(iden.adjoint(x)) if difforder == self.VALUE_ONLY:
if not isinstance(x, Linearization) or not x.want_metric: return Field.scalar(res)
return v if difforder == self.WITH_JAC:
return res
met = makeOp(1./(x.val*(1. - x.val))) met = makeOp(1./(x.val*(1. - x.val)))
met = SandwichOperator.make(x.jac, met) met = SandwichOperator.make(x.jac, met)
return v.add_metric(met) return res.add_metric(met)
class StandardHamiltonian(EnergyOperator): class StandardHamiltonian(EnergyOperator):
...@@ -402,14 +406,14 @@ class StandardHamiltonian(EnergyOperator): ...@@ -402,14 +406,14 @@ class StandardHamiltonian(EnergyOperator):
self._ic_samp = ic_samp self._ic_samp = ic_samp
self._domain = lh.domain self._domain = lh.domain
def apply(self, x): def apply(self, x, difforder):
self._check_input(x) self._check_input(x)
if (self._ic_samp is None or not isinstance(x, Linearization) or not x.want_metric): if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
if difforder <= self.WITH_JAC or self._ic_samp is None:
return (self._lh + self._prior)(x) return (self._lh + self._prior)(x)
else: lhx, prx = self._lh(x), self._prior(x)
lhx, prx = self._lh(x), self._prior(x) return (lhx+prx).add_metric(SamplingEnabler(lhx.metric, prx.metric, self._ic_samp))
mtr = SamplingEnabler(lhx.metric, prx.metric, self._ic_samp)
return (lhx + prx).add_metric(mtr)
def __repr__(self): def __repr__(self):
subs = 'Likelihood:\n{}'.format(utilities.indent(self._lh.__repr__())) subs = 'Likelihood:\n{}'.format(utilities.indent(self._lh.__repr__()))
...@@ -448,13 +452,9 @@ class AveragedEnergy(EnergyOperator): ...@@ -448,13 +452,9 @@ class AveragedEnergy(EnergyOperator):
self._domain = h.domain self._domain = h.domain
self._res_samples = tuple(res_samples) self._res_samples = tuple(res_samples)
def apply(self, x): def apply(self, x, difforder):
self._check_input(x) self._check_input(x)
if isinstance(self._domain, MultiDomain): if difforder >= self.WITH_JAC:
iden = ScalingOperator(self._domain, 1.) x = Linearization.make_var(x, difforder == self.WITH_METRIC)
else: mymap = map(lambda v: self._h(x+v), self._res_samples)
iden = FieldAdapter(self._domain, 'foo') return utilities.my_sum(mymap)/len(self._res_samples)
x = iden.adjoint(x)
from .adder import Adder