diff --git a/nifty6/library/correlated_fields.py b/nifty6/library/correlated_fields.py index 8144df76aa34d1a46b77c2be2537a62d7287d38c..c68343a58c65ca050e9206f7d23eb4b8f6c0640f 100644 --- a/nifty6/library/correlated_fields.py +++ b/nifty6/library/correlated_fields.py @@ -222,10 +222,8 @@ class _Normalization(Operator): self._mode_multiplicity = makeOp(makeField(self._domain, mode_multiplicity)) self._specsum = _SpecialSum(self._domain, space) - def apply(self, x, difforder): + def apply(self, x): self._check_input(x) - if difforder >= self.WITH_JAC: - x = Linearization.make_var(x, difforder == self.WITH_METRIC) amp = x.exp() spec = (2*x).exp() # FIXME This normalizes also the zeromode which is supposed to be left diff --git a/nifty6/library/light_cone_operator.py b/nifty6/library/light_cone_operator.py index 3bd0eaaa9ec9f1561a13a0829aa3223f3b6d823a..4bb040886c84ae0219b65a4caed8855a92a8fafb 100644 --- a/nifty6/library/light_cone_operator.py +++ b/nifty6/library/light_cone_operator.py @@ -131,10 +131,11 @@ class LightConeOperator(Operator): self._target = DomainTuple.make(target) self._sigx = sigx - def apply(self, x, difforder): - a, derivs = _cone_arrays(x.val, self.target, self._sigx, difforder >= self.WITH_JAC) + def apply(self, x): + lin = isinstance(x, Linearization) + a, derivs = _cone_arrays(x.val, self.target, self._sigx, lin) res = Field(self.target, a) - if difforder == self.VALUE_ONLY: + if not lin: return res jac = _LightConeDerivative(self._domain, self._target, derivs) return Linearization(res, jac) diff --git a/nifty6/library/special_distributions.py b/nifty6/library/special_distributions.py index 7515366fbd360bdc208fa0c62e35d8cbe55d12e6..97fa7d34036c9565c061509b78401ad6956819a4 100644 --- a/nifty6/library/special_distributions.py +++ b/nifty6/library/special_distributions.py @@ -38,14 +38,14 @@ class _InterpolationOperator(Operator): self._deriv = (self._table[1:]-self._table[:-1]) / self._d self._inv_table_func = inverse_table_func - def apply(self, x, difforder): + def apply(self, x): self._check_input(x) val = (np.clip(x.val, self._xmin, self._xmax) - self._xmin) / self._d fi = np.floor(val).astype(int) w = val - fi res = self._inv_table_func((1-w)*self._table[fi] + w*self._table[fi+1]) resfld = Field(self._domain, res) - if difforder == self.VALUE_ONLY: + if not isinstance(x, Linearization): return resfld jac = makeOp(Field(self._domain, self._deriv[fi]*res)) return Linearization(resfld, jac) diff --git a/nifty6/linearization.py b/nifty6/linearization.py index bbc54c9be00e0c6a934637fa802a1c495c68f50d..fdd23a0b46b03449e47fcd4458b2a8d3235f7882 100644 --- a/nifty6/linearization.py +++ b/nifty6/linearization.py @@ -63,6 +63,9 @@ class Linearization(object): """ return Linearization(val, jac, metric, self._want_metric) + def trivial_jac(self): + return Linearization.make_var(self._val, self._want_metric) + def prepend_jac(self, jac): metric = None if self._metric is not None: diff --git a/nifty6/operators/adder.py b/nifty6/operators/adder.py index 92fac3afcf414d321a0bb4a0da218133d6af2ddf..6e9f0aece191ca9ead3454879a160476e97883fb 100644 --- a/nifty6/operators/adder.py +++ b/nifty6/operators/adder.py @@ -18,7 +18,6 @@ import numpy as np from ..field import Field -from ..linearization import Linearization from ..multi_field import MultiField from ..sugar import makeDomain from .operator import Operator @@ -43,10 +42,8 @@ class Adder(Operator): self._domain = self._target = dom self._neg = bool(neg) - def apply(self, x, difforder): + def apply(self, x): self._check_input(x) - if difforder >= self.WITH_JAC: - x = Linearization.make_var(x, difforder == self.WITH_METRIC) if self._neg: return x - self._a return x + self._a diff --git a/nifty6/operators/energy_operators.py b/nifty6/operators/energy_operators.py index a7b4dbeca4a6a1655ee93d43a39165f874ce5258..28d08b4b22c0ea315e3216658d264254f43e0dfc 100644 --- a/nifty6/operators/energy_operators.py +++ b/nifty6/operators/energy_operators.py @@ -27,7 +27,6 @@ from ..sugar import makeDomain, makeOp from .linear_operator import LinearOperator from .operator import Operator from .sampling_enabler import SamplingEnabler -from .sandwich_operator import SandwichOperator from .scaling_operator import ScalingOperator from .simple_linear_operators import VdotOperator @@ -58,13 +57,13 @@ class Squared2NormOperator(EnergyOperator): def __init__(self, domain): self._domain = domain - def apply(self, x, difforder): + def apply(self, x): self._check_input(x) - res = x.vdot(x) - if difforder == self.VALUE_ONLY: + if not isinstance(x, Linearization): + res = x.vdot(x) return res - jac = VdotOperator(2*x) - return Linearization(res, jac, want_metric=difforder == self.WITH_METRIC) + res = x.val.vdot(x.val) + return x.new(res, VdotOperator(2*x.val)) class QuadraticFormOperator(EnergyOperator): @@ -87,13 +86,12 @@ class QuadraticFormOperator(EnergyOperator): self._op = endo self._domain = endo.domain - def apply(self, x, difforder): + def apply(self, x): self._check_input(x) - t1 = self._op(x) - res = 0.5*x.vdot(t1) - if difforder == self.VALUE_ONLY: - return res - return Linearization(res, VdotOperator(t1)) + if not isinstance(x, Linearization): + return 0.5*x.vdot(self._op(x)) + res = 0.5*x.val.vdot(self._op(x.val)) + return x.new(res, VdotOperator(self._op(x.val))) class VariableCovarianceGaussianEnergy(EnergyOperator): @@ -127,12 +125,10 @@ class VariableCovarianceGaussianEnergy(EnergyOperator): dom = DomainTuple.make(domain) self._domain = MultiDomain.make({self._r: dom, self._icov: dom}) - def apply(self, x, difforder): + def apply(self, x): self._check_input(x) - if difforder >= self.WITH_JAC: - x = Linearization.make_var(x, difforder == self.WITH_METRIC) res = 0.5*(x[self._r].vdot(x[self._r]*x[self._icov]).real - x[self._icov].log().sum()) - if difforder <= self.WITH_JAC: + if not isinstance(x, Linearization) or not x.want_metric: return res mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)} return res.add_metric(makeOp(MultiField.from_dict(mf))) @@ -195,15 +191,13 @@ class GaussianEnergy(EnergyOperator): if self._domain != newdom: raise ValueError("domain mismatch") - def apply(self, x, difforder): + def apply(self, x): self._check_input(x) - if difforder >= self.WITH_JAC: - x = Linearization.make_var(x, difforder == self.WITH_METRIC) residual = x if self._mean is None else x - self._mean res = self._op(residual).real - if difforder < self.WITH_METRIC: - return res - return res.add_metric(self._met) + if isinstance(x, Linearization) and x.want_metric: + return res.add_metric(self._met) + return res class PoissonianEnergy(EnergyOperator): @@ -233,12 +227,10 @@ class PoissonianEnergy(EnergyOperator): self._d = d self._domain = DomainTuple.make(d.domain) - def apply(self, x, difforder): + def apply(self, x): self._check_input(x) - if difforder >= self.WITH_JAC: - x = Linearization.make_var(x, difforder == self.WITH_METRIC) res = x.sum() - x.log().vdot(self._d) - if difforder <= self.WITH_JAC: + if not isinstance(x, Linearization) or not x.want_metric: return res return res.add_metric(makeOp(1./x.val)) @@ -275,12 +267,10 @@ class InverseGammaLikelihood(EnergyOperator): raise TypeError self._alphap1 = alpha+1 - def apply(self, x, difforder): + def apply(self, x): self._check_input(x) - if difforder >= self.WITH_JAC: - x = Linearization.make_var(x, difforder == self.WITH_METRIC) res = x.log().vdot(self._alphap1) + x.one_over().vdot(self._beta) - if difforder <= self.WITH_JAC: + if not isinstance(x, Linearization) or not x.want_metric: return res return res.add_metric(makeOp(self._alphap1/(x.val**2))) @@ -306,12 +296,10 @@ class StudentTEnergy(EnergyOperator): self._domain = DomainTuple.make(domain) self._theta = theta - def apply(self, x, difforder): + def apply(self, x): self._check_input(x) - if difforder >= self.WITH_JAC: - x = Linearization.make_var(x, difforder == self.WITH_METRIC) res = ((self._theta+1)/2)*(x**2/self._theta).log1p().sum() - if difforder <= self.WITH_JAC: + if not isinstance(x, Linearization) or not x.want_metric: return res met = ScalingOperator(self.domain, (self._theta+1) / (self._theta+3)) return res.add_metric(met) @@ -342,16 +330,12 @@ class BernoulliEnergy(EnergyOperator): self._d = d self._domain = DomainTuple.make(d.domain) - def apply(self, x, difforder): + def apply(self, x): self._check_input(x) - if difforder >= self.WITH_JAC: - x = Linearization.make_var(x, difforder == self.WITH_METRIC) res = -x.log().vdot(self._d) + (1.-x).log().vdot(self._d-1.) - if difforder <= self.WITH_JAC: + if not isinstance(x, Linearization) or not x.want_metric: return res - met = makeOp(1./(x.val*(1. - x.val))) - met = SandwichOperator.make(x.jac, met) - return res.add_metric(met) + return res.add_metric(makeOp(1./(x.val*(1. - x.val)))) class StandardHamiltonian(EnergyOperator): @@ -396,11 +380,9 @@ class StandardHamiltonian(EnergyOperator): self._ic_samp = ic_samp self._domain = lh.domain - def apply(self, x, difforder): + def apply(self, x): self._check_input(x) - if difforder >= self.WITH_JAC: - x = Linearization.make_var(x, difforder == self.WITH_METRIC) - if difforder <= self.WITH_JAC or self._ic_samp is None: + if not isinstance(x, Linearization) or not x.want_metric or self._ic_samp is None: return (self._lh + self._prior)(x) lhx, prx = self._lh(x), self._prior(x) return (lhx+prx).add_metric(SamplingEnabler(lhx.metric, prx.metric, self._ic_samp)) @@ -442,9 +424,7 @@ class AveragedEnergy(EnergyOperator): self._domain = h.domain self._res_samples = tuple(res_samples) - def apply(self, x, difforder): + def apply(self, x): self._check_input(x) - if difforder >= self.WITH_JAC: - x = Linearization.make_var(x, difforder == self.WITH_METRIC) mymap = map(lambda v: self._h(x+v), self._res_samples) return utilities.my_sum(mymap)/len(self._res_samples) diff --git a/nifty6/operators/operator.py b/nifty6/operators/operator.py index 41ed0157be194bfd5210318026140b807aee98b9..a2b509444488a98d5ef2c48486d8da47e9448169 100644 --- a/nifty6/operators/operator.py +++ b/nifty6/operators/operator.py @@ -166,7 +166,7 @@ class Operator(metaclass=NiftyMeta): return self return _OpChain.make((_Clipper(self.target, min, max), self)) - def apply(self, x, difforder): + def apply(self, x): """Applies the operator to a Field or MultiField. Parameters @@ -183,7 +183,8 @@ class Operator(metaclass=NiftyMeta): return self.apply(x.extract(self.domain), 0) def _check_input(self, x): - if not isinstance(x, (Field, MultiField)): + from ..linearization import Linearization + if not isinstance(x, (Field, MultiField, Linearization)): raise TypeError self._check_domain_equality(self._domain, x.domain) @@ -192,10 +193,9 @@ class Operator(metaclass=NiftyMeta): from ..field import Field from ..multi_field import MultiField if isinstance(x, Linearization): - difforder = self.WITH_METRIC if x.want_metric else self.WITH_JAC - return self.apply(x.val, difforder).prepend_jac(x.jac) + return self.apply(x.trivial_jac()).prepend_jac(x.jac) elif isinstance(x, (Field, MultiField)): - return self.apply(x, self.VALUE_ONLY) + return self.apply(x) raise TypeError('Operator can only consume Field, MultiFields and Linearizations') def ducktape(self, name): @@ -279,12 +279,12 @@ class _ConstantOperator(Operator): self._target = output.domain self._output = output - def apply(self, x, difforder): + def apply(self, x): from ..linearization import Linearization from .simple_linear_operators import NullOperator self._check_input(x) - if difforder >= self.WITH_JAC: - return Linearization(self._output, NullOperator(self._domain, self._target)) + if isinstance(x, Linearization): + return x.new(self._output, NullOperator(self._domain, self._target)) return self._output def __repr__(self): @@ -297,11 +297,8 @@ class _FunctionApplier(Operator): self._domain = self._target = makeDomain(domain) self._funcname = funcname - def apply(self, x, difforder): + def apply(self, x): self._check_input(x) - from ..linearization import Linearization - if difforder >= self.WITH_JAC: - x = Linearization.make_var(x, difforder == self.WITH_METRIC) return getattr(x, self._funcname)() @@ -312,11 +309,8 @@ class _Clipper(Operator): self._min = min self._max = max - def apply(self, x, difforder): + def apply(self, x): self._check_input(x) - from ..linearization import Linearization - if difforder >= self.WITH_JAC: - x = Linearization.make_var(x, difforder == self.WITH_METRIC) return x.clip(self._min, self._max) @@ -326,11 +320,8 @@ class _PowerOp(Operator): self._domain = self._target = makeDomain(domain) self._power = power - def apply(self, x, difforder): + def apply(self, x): self._check_input(x) - from ..linearization import Linearization - if difforder >= self.WITH_JAC: - x = Linearization.make_var(x, difforder == self.WITH_METRIC) return x**self._power @@ -366,11 +357,8 @@ class _OpChain(_CombinedOperator): if self._ops[i-1].domain != self._ops[i].target: raise ValueError("domain mismatch") - def apply(self, x, difforder): + def apply(self, x): self._check_input(x) - if difforder >= self.WITH_JAC: - from ..linearization import Linearization - x = Linearization.make_var(x, difforder == self.WITH_METRIC) for op in reversed(self._ops): x = op(x) return x @@ -401,15 +389,17 @@ class _OpProd(Operator): self._op1 = op1 self._op2 = op2 - def apply(self, x, difforder): + def apply(self, x): from ..linearization import Linearization from ..sugar import makeOp self._check_input(x) + lin = isinstance(x, Linearization) + wm = x.want_metric if lin else None + x = x.val if lin else x v1 = x.extract(self._op1.domain) v2 = x.extract(self._op2.domain) - if difforder == self.VALUE_ONLY: + if not lin: return self._op1(v1) * self._op2(v2) - wm = difforder == self.WITH_METRIC lin1 = self._op1(Linearization.make_var(v1, wm)) lin2 = self._op2(Linearization.make_var(v2, wm)) jac = (makeOp(lin1._val)(lin2._jac))._myadd(makeOp(lin2._val)(lin1._jac), False) @@ -443,14 +433,16 @@ class _OpSum(Operator): self._op1 = op1 self._op2 = op2 - def apply(self, x, difforder): + def apply(self, x): from ..linearization import Linearization self._check_input(x) - v1 = x.extract(self._op1.domain) - v2 = x.extract(self._op2.domain) - if difforder == self.VALUE_ONLY: + if not isinstance(x, Linearization): + v1 = x.extract(self._op1.domain) + v2 = x.extract(self._op2.domain) return self._op1(v1).unite(self._op2(v2)) - wm = difforder == self.WITH_METRIC + v1 = x.val.extract(self._op1.domain) + v2 = x.val.extract(self._op2.domain) + wm = x.want_metric lin1 = self._op1(Linearization.make_var(v1, wm)) lin2 = self._op2(Linearization.make_var(v2, wm)) op = lin1._jac._myadd(lin2._jac, False)