Commit a8f5fcd1 authored by Philipp Arras's avatar Philipp Arras

New apply logic

parent a778ee76
Pipeline #70576 failed with stages
in 13 minutes and 4 seconds
......@@ -15,7 +15,11 @@ from .multi_domain import MultiDomain
from .field import Field
from .multi_field import MultiField
from .linearization import Linearization
from .operators.operator import Operator
from .operators.linear_operator import LinearOperator
from .operators.adder import Adder
from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import DOFDistributor, PowerDistributor
......@@ -28,7 +32,6 @@ from .operators.harmonic_operators import (
HarmonicSmoothingOperator)
from .operators.field_zero_padder import FieldZeroPadder
from .operators.inversion_enabler import InversionEnabler
from .operators.linear_operator import LinearOperator
from .operators.mask_operator import MaskOperator
from .operators.regridding_operator import RegriddingOperator
from .operators.sampling_enabler import SamplingEnabler
......@@ -87,8 +90,6 @@ from .utilities import memo, frozendict
from .logger import logger
from .linearization import Linearization
from .operator_spectrum import operator_spectrum
# We deliberately don't set __all__ here, because we don't want people to do a
......
......@@ -114,7 +114,7 @@ def _actual_domain_check_nonlinear(op, loc):
assert_(reslin.jac.target is reslin.target)
_actual_domain_check_linear(reslin.jac, inp=loc)
_actual_domain_check_linear(reslin.jac.adjoint, inp=reslin.jac(loc))
if wm:
if reslin.metric is not None:
assert_(reslin.metric.domain is reslin.metric.target)
assert_(reslin.metric.domain is op.domain)
......@@ -153,7 +153,7 @@ def _performance_check(op, pos, raise_on_fail):
cond.append(cop.count != 3)
lin.jac.adjoint(lin.val)
cond.append(cop.count != 4)
if wm and myop.target is DomainTuple.scalar_domain():
if lin.metric is not None:
lin.metric(pos)
cond.append(cop.count != 6)
if any(cond):
......
......@@ -25,6 +25,7 @@ from ..domain_tuple import DomainTuple
from ..domains.power_space import PowerSpace
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..linearization import Linearization
from ..logger import logger
from ..multi_field import MultiField
from ..operators.adder import Adder
......@@ -221,14 +222,15 @@ class _Normalization(Operator):
self._mode_multiplicity = makeOp(makeField(self._domain, mode_multiplicity))
self._specsum = _SpecialSum(self._domain, space)
def apply(self, x):
def apply(self, x, difforder):
self._check_input(x)
fa = FieldAdapter(self._domain, 'foo')
amp = fa.exp()
spec = (2*fa).exp()
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
amp = x.exp()
spec = (2*x).exp()
# FIXME This normalizes also the zeromode which is supposed to be left
# untouched by this operator
return (self._specsum(self._mode_multiplicity(spec))**(-0.5)*amp)(fa.adjoint(x))
return self._specsum(self._mode_multiplicity(spec))**(-0.5)*amp
class _SpecialSum(EndomorphicOperator):
......
......@@ -34,17 +34,8 @@ def _float_or_listoffloat(inp):
return [float(x) for x in inp] if isinstance(inp, list) else float(inp)
def _make_dynamic_operator(target,
harmonic_padding,
sm_s0,
sm_x0,
cone,
keys,
causal,
minimum_phase,
sigc=None,
quant=None,
codomain=None):
def _make_dynamic_operator(target, harmonic_padding, sm_s0, sm_x0, cone, keys, causal,
minimum_phase, sigc=None, quant=None, codomain=None):
if not isinstance(target, RGSpace):
raise TypeError("RGSpace required")
if not target.harmonic:
......@@ -128,7 +119,7 @@ def _make_dynamic_operator(target,
scaling = DiagonalOperator(Field(c.target, scaling))
ops['lightspeed'] = scaling(lightspeed)
c = LightConeOperator(c.target, m.target, quant)(c.exp())
c = LightConeOperator(c.target, m.target, quant) @ c.exp()
ops['light_cone'] = c
m = c*m
......@@ -139,13 +130,7 @@ def _make_dynamic_operator(target,
return m, ops
def dynamic_operator(*,
target,
harmonic_padding,
sm_s0,
sm_x0,
key,
causal=True,
def dynamic_operator(*, target, harmonic_padding, sm_s0, sm_x0, key, causal=True,
minimum_phase=False):
"""Constructs an operator encoding the Green's function of a linear
homogeneous dynamic system.
......@@ -206,17 +191,8 @@ def dynamic_operator(*,
return _make_dynamic_operator(**dct)
def dynamic_lightcone_operator(*,
target,
harmonic_padding,
sm_s0,
sm_x0,
key,
lightcone_key,
sigc,
quant,
causal=True,
minimum_phase=False):
def dynamic_lightcone_operator(*, target, harmonic_padding, sm_s0, sm_x0, key, lightcone_key,
sigc, quant, causal=True, minimum_phase=False):
'''Extends the functionality of :func:`dynamic_operator` to a Green's
function which is constrained to be within a light cone.
......
......@@ -131,12 +131,10 @@ class LightConeOperator(Operator):
self._target = DomainTuple.make(target)
self._sigx = sigx
def apply(self, x):
islin = isinstance(x, Linearization)
val = x.val.val if islin else x.val
a, derivs = _cone_arrays(val, self.target, self._sigx, islin)
def apply(self, x, difforder):
a, derivs = _cone_arrays(x.val, self.target, self._sigx, difforder >= self.WITH_JAC)
res = Field(self.target, a)
if not islin:
if difforder == self.VALUE_ONLY:
return res
jac = _LightConeDerivative(x.jac.target, self.target, derivs)(x.jac)
return Linearization(res, jac, want_metric=x.want_metric)
jac = _LightConeDerivative(self._domain, self._target, derivs)
return Linearization(res, jac)
......@@ -38,19 +38,17 @@ class _InterpolationOperator(Operator):
self._deriv = (self._table[1:]-self._table[:-1]) / self._d
self._inv_table_func = inverse_table_func
def apply(self, x):
def apply(self, x, difforder):
self._check_input(x)
lin = isinstance(x, Linearization)
val = x.val.val if lin else x.val
val = (np.clip(val, self._xmin, self._xmax) - self._xmin) / self._d
val = (np.clip(x.val, self._xmin, self._xmax) - self._xmin) / self._d
fi = np.floor(val).astype(int)
w = val - fi
res = self._inv_table_func((1-w)*self._table[fi] + w*self._table[fi+1])
resfld = Field(self._domain, res)
if not lin:
if difforder == self.VALUE_ONLY:
return resfld
jac = makeOp(Field(self._domain, self._deriv[fi]*res)) @ x.jac
return x.new(resfld, jac)
jac = makeOp(Field(self._domain, self._deriv[fi]*res))
return Linearization(resfld, jac)
def InverseGammaOperator(domain, alpha, q, delta=0.001):
......
......@@ -63,6 +63,13 @@ class Linearization(object):
"""
return Linearization(val, jac, metric, self._want_metric)
def prepend_jac(self, jac):
metric = None
if self._metric is not None:
from .operators.sandwich_operator import SandwichOperator
metric = None if self._metric is None else SandwichOperator.make(jac, self._metric)
return self.new(self._val, self._jac @ jac, metric)
@property
def domain(self):
"""DomainTuple or MultiDomain : the Jacobian's domain"""
......
......@@ -18,9 +18,10 @@
import numpy as np
from ..field import Field
from ..linearization import Linearization
from ..multi_field import MultiField
from .operator import Operator
from ..sugar import makeDomain
from .operator import Operator
class Adder(Operator):
......@@ -42,8 +43,10 @@ class Adder(Operator):
self._domain = self._target = dom
self._neg = bool(neg)
def apply(self, x):
def apply(self, x, difforder):
self._check_input(x)
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
if self._neg:
return x - self._a
return x + self._a
......@@ -58,13 +58,13 @@ class Squared2NormOperator(EnergyOperator):
def __init__(self, domain):
self._domain = domain
def apply(self, x):
def apply(self, x, difforder):
self._check_input(x)
if isinstance(x, Linearization):
val = Field.scalar(x.val.vdot(x.val))
jac = VdotOperator(2*x.val)(x.jac)
return x.new(val, jac)
return Field.scalar(x.vdot(x))
res = Field.scalar(x.vdot(x))
if difforder == self.VALUE_ONLY:
return res
jac = VdotOperator(2*x)
return Linearization(res, jac, want_metric=difforder == self.WITH_METRIC)
class QuadraticFormOperator(EnergyOperator):
......@@ -87,14 +87,13 @@ class QuadraticFormOperator(EnergyOperator):
self._op = endo
self._domain = endo.domain
def apply(self, x):
def apply(self, x, difforder):
self._check_input(x)
if isinstance(x, Linearization):
t1 = self._op(x.val)
jac = VdotOperator(t1)(x.jac)
val = Field.scalar(0.5*x.val.vdot(t1))
return x.new(val, jac)
return Field.scalar(0.5*x.vdot(self._op(x)))
t1 = self._op(x)
res = Field.scalar(0.5*x.vdot(t1))
if difforder == self.VALUE_ONLY:
return res
return Linearization(res, VdotOperator(t1))
class VariableCovarianceGaussianEnergy(EnergyOperator):
......@@ -128,19 +127,17 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
dom = DomainTuple.make(domain)
self._domain = MultiDomain.make({self._r: dom, self._icov: dom})
def apply(self, x):
def apply(self, x, difforder):
self._check_input(x)
lin = isinstance(x, Linearization)
r = FieldAdapter(self._domain[self._r], self._r)
icov = FieldAdapter(self._domain[self._icov], self._icov)
res0 = r.vdot(r*icov).real
res1 = icov.log().sum()
res = (res0-res1).scale(0.5)(x)
if not lin or not x.want_metric:
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
res = 0.5*(x[self._r].vdot(x[self._r]*x[self._icov]).real - x[self._icov].log().sum())
if difforder == self.VALUE_ONLY:
return Field.scalar(res)
if difforder == self.WITH_JAC:
return res
mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)}
metric = makeOp(MultiField.from_dict(mf))
return res.add_metric(SandwichOperator.make(x.jac, metric))
return res.add_metric(makeOp(MultiField.from_dict(mf)))
class GaussianEnergy(EnergyOperator):
......@@ -187,9 +184,10 @@ class GaussianEnergy(EnergyOperator):
self._mean = mean
if inverse_covariance is None:
self._op = Squared2NormOperator(self._domain).scale(0.5)
self._met = ScalingOperator(self._domain, 1)
else:
self._op = QuadraticFormOperator(inverse_covariance)
self._icov = None if inverse_covariance is None else inverse_covariance
self._met = inverse_covariance
def _checkEquivalence(self, newdom):
newdom = makeDomain(newdom)
......@@ -199,14 +197,15 @@ class GaussianEnergy(EnergyOperator):
if self._domain != newdom:
raise ValueError("domain mismatch")
def apply(self, x):
def apply(self, x, difforder):
self._check_input(x)
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
residual = x if self._mean is None else x - self._mean
res = self._op(residual).real
if not isinstance(x, Linearization) or not x.want_metric:
if difforder < self.WITH_METRIC:
return res
metric = SandwichOperator.make(x.jac, self._icov)
return res.add_metric(metric)
return res.add_metric(self._met)
class PoissonianEnergy(EnergyOperator):
......@@ -236,14 +235,16 @@ class PoissonianEnergy(EnergyOperator):
self._d = d
self._domain = DomainTuple.make(d.domain)
def apply(self, x):
def apply(self, x, difforder):
self._check_input(x)
fa = FieldAdapter(self._domain, 'foo')
res = (fa.sum() - fa.log().vdot(self._d))(fa.adjoint(x))
if not isinstance(x, Linearization) or not x.want_metric:
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
res = x.sum() - x.log().vdot(self._d)
if difforder == self.VALUE_ONLY:
return Field.scalar(res)
if difforder == self.WITH_JAC:
return res
metric = SandwichOperator.make(x.jac, makeOp(1./x.val))
return res.add_metric(metric)
return res.add_metric(makeOp(1./x.val))
class InverseGammaLikelihood(EnergyOperator):
......@@ -278,15 +279,16 @@ class InverseGammaLikelihood(EnergyOperator):
raise TypeError
self._alphap1 = alpha+1
def apply(self, x):
def apply(self, x, difforder):
self._check_input(x)
fa = FieldAdapter(self._domain, 'foo')
x = fa.adjoint(x)
res = (fa.log().vdot(self._alphap1) + fa.one_over().vdot(self._beta))(x)
if not isinstance(x, Linearization) or not x.want_metric:
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
res = x.log().vdot(self._alphap1) + x.one_over().vdot(self._beta)
if difforder == self.VALUE_ONLY:
return Field.scalar(res)
if difforder == self.WITH_JAC:
return res
metric = SandwichOperator.make(x.jac, makeOp(self._alphap1/(x.val**2)))
return res.add_metric(metric)
return res.add_metric(makeOp(self._alphap1/(x.val**2)))
class StudentTEnergy(EnergyOperator):
......@@ -310,16 +312,17 @@ class StudentTEnergy(EnergyOperator):
self._domain = DomainTuple.make(domain)
self._theta = theta
def apply(self, x):
def apply(self, x, difforder):
self._check_input(x)
v = ((self._theta+1)/2)*(x**2/self._theta).log1p().sum()
if not isinstance(x, Linearization):
return Field.scalar(v)
if not x.want_metric:
return v
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
res = ((self._theta+1)/2)*(x**2/self._theta).log1p().sum()
if difforder == self.VALUE_ONLY:
return Field.scalar(res)
if difforder == self.WITH_JAC:
return res
met = ScalingOperator(self.domain, (self._theta+1) / (self._theta+3))
met = SandwichOperator.make(x.jac, met)
return v.add_metric(met)
return res.add_metric(met)
class BernoulliEnergy(EnergyOperator):
......@@ -347,17 +350,18 @@ class BernoulliEnergy(EnergyOperator):
self._d = d
self._domain = DomainTuple.make(d.domain)
def apply(self, x):
def apply(self, x, difforder):
self._check_input(x)
iden = FieldAdapter(self._domain, 'foo')
from .adder import Adder
v = -iden.log().vdot(self._d) + (Adder(1, domain=self._domain) @ iden.scale(-1)).log().vdot(self._d-1.)
v = v(iden.adjoint(x))
if not isinstance(x, Linearization) or not x.want_metric:
return v
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
res = -x.log().vdot(self._d) + (1.-x).log().vdot(self._d-1.)
if difforder == self.VALUE_ONLY:
return Field.scalar(res)
if difforder == self.WITH_JAC:
return res
met = makeOp(1./(x.val*(1. - x.val)))
met = SandwichOperator.make(x.jac, met)
return v.add_metric(met)
return res.add_metric(met)
class StandardHamiltonian(EnergyOperator):
......@@ -402,14 +406,14 @@ class StandardHamiltonian(EnergyOperator):
self._ic_samp = ic_samp
self._domain = lh.domain
def apply(self, x):
def apply(self, x, difforder):
self._check_input(x)
if (self._ic_samp is None or not isinstance(x, Linearization) or not x.want_metric):
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
if difforder <= self.WITH_JAC or self._ic_samp is None:
return (self._lh + self._prior)(x)
else:
lhx, prx = self._lh(x), self._prior(x)
mtr = SamplingEnabler(lhx.metric, prx.metric, self._ic_samp)
return (lhx + prx).add_metric(mtr)
lhx, prx = self._lh(x), self._prior(x)
return (lhx+prx).add_metric(SamplingEnabler(lhx.metric, prx.metric, self._ic_samp))
def __repr__(self):
subs = 'Likelihood:\n{}'.format(utilities.indent(self._lh.__repr__()))
......@@ -448,13 +452,9 @@ class AveragedEnergy(EnergyOperator):
self._domain = h.domain
self._res_samples = tuple(res_samples)
def apply(self, x):
def apply(self, x, difforder):
self._check_input(x)
if isinstance(self._domain, MultiDomain):
iden = ScalingOperator(self._domain, 1.)
else:
iden = FieldAdapter(self._domain, 'foo')
x = iden.adjoint(x)
from .adder import Adder
mymap = map(lambda v: self._h(Adder(v) @ iden), self._res_samples)
return utilities.my_sum(mymap).scale(1./len(self._res_samples))(x)
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
mymap = map(lambda v: self._h(x+v), self._res_samples)
return utilities.my_sum(mymap)/len(self._res_samples)
......@@ -174,7 +174,7 @@ class LinearOperator(Operator):
return self.apply(x, self.TIMES)
from ..linearization import Linearization
if isinstance(x, Linearization):
return x.new(self(x._val), self(x._jac))
return x.new(self(x._val), self).prepend_jac(x.jac)
return self@x
def times(self, x):
......
......@@ -16,6 +16,9 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..field import Field
from ..multi_field import MultiField
from ..utilities import NiftyMeta, indent
......@@ -24,6 +27,10 @@ class Operator(metaclass=NiftyMeta):
domain, and can also provide the Jacobian.
"""
VALUE_ONLY = 0
WITH_JAC = 1
WITH_METRIC = 2
@property
def domain(self):
"""The domain on which the Operator's input Field is defined.
......@@ -159,7 +166,7 @@ class Operator(metaclass=NiftyMeta):
return self
return _OpChain.make((_Clipper(self.target, min, max), self))
def apply(self, x):
def apply(self, x, difforder):
"""Applies the operator to a Field or MultiField.
Parameters
......@@ -176,22 +183,28 @@ class Operator(metaclass=NiftyMeta):
return self.apply(x.extract(self.domain))
def _check_input(self, x):
from ..linearization import Linearization
d = x.target if isinstance(x, Linearization) else x.domain
self._check_domain_equality(self._domain, d)
if not isinstance(x, (Field, MultiField)):
raise TypeError
self._check_domain_equality(self._domain, x.domain)
def __call__(self, x):
if isinstance(x, Operator):
return _OpChain.make((self, x))
return self.apply(x)
from ..linearization import Linearization
from ..field import Field
from ..multi_field import MultiField
if isinstance(x, Linearization):
difforder = self.WITH_METRIC if x.want_metric else self.WITH_JAC
return self.apply(x.val, difforder).prepend_jac(x.jac)
elif isinstance(x, (Field, MultiField)):
return self.apply(x, self.VALUE_ONLY)
raise TypeError('Operator can only consume Field, MultiFields and Linearizations')
def ducktape(self, name):
from .simple_linear_operators import ducktape
return self(ducktape(self, None, name))
return self @ ducktape(self, None, name)
def ducktape_left(self, name):
from .simple_linear_operators import ducktape
return ducktape(None, self, name)(self)
return ducktape(None, self, name) @ self
def __repr__(self):
return self.__class__.__name__
......@@ -266,19 +279,13 @@ class _ConstantOperator(Operator):
self._target = output.domain
self._output = output
def apply(self, x):
def apply(self, x, difforder):
from ..linearization import Linearization
from .simple_linear_operators import NullOperator
from ..domain_tuple import DomainTuple
self._check_input(x)
if not isinstance(x, Linearization):
return self._output
if x.want_metric and self._target is DomainTuple.scalar_domain():
met = NullOperator(self._domain, self._domain)
else:
met = None
return x.new(self._output, NullOperator(self._domain, self._target),
met)
if difforder >= self.WITH_JAC:
return Linearization(self._output, NullOperator(self._domain, self._target))
return self._output
def __repr__(self):
return 'ConstantOperator <- {}'.format(self.domain.keys())
......@@ -290,8 +297,11 @@ class _FunctionApplier(Operator):
self._domain = self._target = makeDomain(domain)
self._funcname = funcname
def apply(self, x):
def apply(self, x, difforder):
self._check_input(x)
from ..linearization import Linearization
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
return getattr(x, self._funcname)()
......@@ -302,8 +312,11 @@ class _Clipper(Operator):
self._min = min
self._max = max
def apply(self, x):
def apply(self, x, difforder):
self._check_input(x)
from ..linearization import Linearization
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
return x.clip(self._min, self._max)