Commit f24e26e9 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more

parent 14daede2
......@@ -19,7 +19,6 @@ import numpy as np
from ..domain_tuple import DomainTuple
from ..field import Field
from ..linearization import Linearization
from ..operators.linear_operator import LinearOperator
from ..operators.operator import Operator
......@@ -132,7 +131,7 @@ class LightConeOperator(Operator):
self._sigx = sigx
def apply(self, x):
lin = isinstance(x, Linearization)
lin = x.jac is not None
a, derivs = _cone_arrays(x.val.val if lin else x.val, self.target, self._sigx, lin)
res = Field(self.target, a)
if not lin:
......
......@@ -22,7 +22,6 @@ from scipy.interpolate import CubicSpline
from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..linearization import Linearization
from ..operators.operator import Operator
from ..sugar import makeOp
from .. import random
......@@ -79,7 +78,7 @@ class _InterpolationOperator(Operator):
def apply(self, x):
self._check_input(x)
lin = isinstance(x, Linearization)
lin = x.jac is not None
xval = x.val.val if lin else x.val
res = self._interpolator(xval)
res = Field(self._domain, res)
......@@ -148,7 +147,7 @@ class UniformOperator(Operator):
def apply(self, x):
self._check_input(x)
lin = isinstance(x, Linearization)
lin = x.jac is not None
xval = x.val.val if lin else x.val
res = Field(self._target, self._scale*norm._cdf(xval) + self._loc)
if not lin:
......
......@@ -20,7 +20,6 @@ import numpy as np
from .. import utilities
from ..domain_tuple import DomainTuple
from ..field import Field
from ..linearization import Linearization
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..sugar import makeDomain, makeOp
......@@ -59,9 +58,8 @@ class Squared2NormOperator(EnergyOperator):
def apply(self, x):
self._check_input(x)
if not isinstance(x, Linearization):
res = x.vdot(x)
return res
if x.jac is None:
return x.vdot(x)
res = x.val.vdot(x.val)
return x.new(res, VdotOperator(2*x.val))
......@@ -88,7 +86,7 @@ class QuadraticFormOperator(EnergyOperator):
def apply(self, x):
self._check_input(x)
if not isinstance(x, Linearization):
if x.jac is None:
return 0.5*x.vdot(self._op(x))
res = 0.5*x.val.vdot(self._op(x.val))
return x.new(res, VdotOperator(self._op(x.val)))
......@@ -128,7 +126,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
def apply(self, x):
self._check_input(x)
res = 0.5*(x[self._r].vdot(x[self._r]*x[self._icov]).real - x[self._icov].ptw("log").sum())
if not isinstance(x, Linearization) or not x.want_metric:
if not x.want_metric:
return res
mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)}
return res.add_metric(makeOp(MultiField.from_dict(mf)))
......@@ -195,7 +193,7 @@ class GaussianEnergy(EnergyOperator):
self._check_input(x)
residual = x if self._mean is None else x - self._mean
res = self._op(residual).real
if isinstance(x, Linearization) and x.want_metric:
if x.want_metric:
return res.add_metric(self._met)
return res
......@@ -230,7 +228,7 @@ class PoissonianEnergy(EnergyOperator):
def apply(self, x):
self._check_input(x)
res = x.sum() - x.ptw("log").vdot(self._d)
if not isinstance(x, Linearization) or not x.want_metric:
if not x.want_metric:
return res
return res.add_metric(makeOp(1./x.val))
......@@ -270,7 +268,7 @@ class InverseGammaLikelihood(EnergyOperator):
def apply(self, x):
self._check_input(x)
res = x.ptw("log").vdot(self._alphap1) + x.ptw("reciprocal").vdot(self._beta)
if not isinstance(x, Linearization) or not x.want_metric:
if not x.want_metric:
return res
return res.add_metric(makeOp(self._alphap1/(x.val**2)))
......@@ -299,7 +297,7 @@ class StudentTEnergy(EnergyOperator):
def apply(self, x):
self._check_input(x)
res = ((self._theta+1)/2)*(x**2/self._theta).ptw("log1p").sum()
if not isinstance(x, Linearization) or not x.want_metric:
if not x.want_metric:
return res
met = ScalingOperator(self.domain, (self._theta+1) / (self._theta+3))
return res.add_metric(met)
......@@ -333,7 +331,7 @@ class BernoulliEnergy(EnergyOperator):
def apply(self, x):
self._check_input(x)
res = -x.ptw("log").vdot(self._d) + (1.-x).ptw("log").vdot(self._d-1.)
if not isinstance(x, Linearization) or not x.want_metric:
if not x.want_metric:
return res
return res.add_metric(makeOp(1./(x.val*(1. - x.val))))
......@@ -382,7 +380,7 @@ class StandardHamiltonian(EnergyOperator):
def apply(self, x):
self._check_input(x)
if not isinstance(x, Linearization) or not x.want_metric or self._ic_samp is None:
if not x.want_metric or self._ic_samp is None:
return (self._lh + self._prior)(x)
lhx, prx = self._lh(x), self._prior(x)
return (lhx+prx).add_metric(SamplingEnabler(lhx.metric, prx.metric, self._ic_samp))
......
......@@ -171,10 +171,10 @@ class LinearOperator(Operator):
def __call__(self, x):
"""Same as :meth:`times`"""
from ..linearization import Linearization
if isinstance(x, (Field, MultiField)):
return self.apply(x, self.TIMES)
if isinstance(x, Linearization):
if x.jac is not None:
return x.new(self(x._val), self).prepend_jac(x.jac)
if x.val is not None:
return self.apply(x, self.TIMES)
return self@x
def times(self, x):
......
......@@ -329,10 +329,9 @@ class _ConstantOperator(Operator):
self._output = output
def apply(self, x):
from ..linearization import Linearization
from .simple_linear_operators import NullOperator
self._check_input(x)
if isinstance(x, Linearization):
if x.jac is not None:
return x.new(self._output, NullOperator(self._domain, self._target))
return self._output
......@@ -421,8 +420,8 @@ class _OpProd(Operator):
from ..linearization import Linearization
from ..sugar import makeOp
self._check_input(x)
lin = isinstance(x, Linearization)
wm = x.want_metric if lin else None
lin = x.jac is not None
wm = x.want_metric if lin else False
x = x.val if lin else x
v1 = x.extract(self._op1.domain)
v2 = x.extract(self._op2.domain)
......@@ -464,7 +463,7 @@ class _OpSum(Operator):
def apply(self, x):
from ..linearization import Linearization
self._check_input(x)
if not isinstance(x, Linearization):
if x.jac is None:
v1 = x.extract(self._op1.domain)
v2 = x.extract(self._op2.domain)
return self._op1(v1).unite(self._op2(v2))
......
......@@ -99,8 +99,7 @@ class ScalingOperator(EndomorphicOperator):
def __call__(self, other):
res = EndomorphicOperator.__call__(self, other)
if np.isreal(self._factor) and self._factor >= 0:
from ..linearization import Linearization
if isinstance(other, Linearization):
if other.jac is not None:
if other.metric is not None:
from .sandwich_operator import SandwichOperator
sqrt_fac = np.sqrt(self._factor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment