There is a maintenance of MPCDF Gitlab on Thursday, April 22st 2020, 9:00 am CEST - Expect some service interruptions during this time

Commit 1c7629b0 authored by Philipp Arras's avatar Philipp Arras

Port energies

parent 2a489298
Pipeline #70612 failed with stages
in 11 minutes and 28 seconds
......@@ -222,10 +222,8 @@ class _Normalization(Operator):
self._mode_multiplicity = makeOp(makeField(self._domain, mode_multiplicity))
self._specsum = _SpecialSum(self._domain, space)
def apply(self, x, difforder):
def apply(self, x):
self._check_input(x)
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
amp = x.exp()
spec = (2*x).exp()
# FIXME This normalizes also the zeromode which is supposed to be left
......
......@@ -131,10 +131,11 @@ class LightConeOperator(Operator):
self._target = DomainTuple.make(target)
self._sigx = sigx
def apply(self, x, difforder):
a, derivs = _cone_arrays(x.val, self.target, self._sigx, difforder >= self.WITH_JAC)
def apply(self, x):
lin = isinstance(x, Linearization)
a, derivs = _cone_arrays(x.val, self.target, self._sigx, lin)
res = Field(self.target, a)
if difforder == self.VALUE_ONLY:
if not lin:
return res
jac = _LightConeDerivative(self._domain, self._target, derivs)
return Linearization(res, jac)
......@@ -38,14 +38,14 @@ class _InterpolationOperator(Operator):
self._deriv = (self._table[1:]-self._table[:-1]) / self._d
self._inv_table_func = inverse_table_func
def apply(self, x, difforder):
def apply(self, x):
self._check_input(x)
val = (np.clip(x.val, self._xmin, self._xmax) - self._xmin) / self._d
fi = np.floor(val).astype(int)
w = val - fi
res = self._inv_table_func((1-w)*self._table[fi] + w*self._table[fi+1])
resfld = Field(self._domain, res)
if difforder == self.VALUE_ONLY:
if not isinstance(x, Linearization):
return resfld
jac = makeOp(Field(self._domain, self._deriv[fi]*res))
return Linearization(resfld, jac)
......
......@@ -63,6 +63,9 @@ class Linearization(object):
"""
return Linearization(val, jac, metric, self._want_metric)
def trivial_jac(self):
return Linearization.make_var(self._val, self._want_metric)
def prepend_jac(self, jac):
metric = None
if self._metric is not None:
......
......@@ -18,7 +18,6 @@
import numpy as np
from ..field import Field
from ..linearization import Linearization
from ..multi_field import MultiField
from ..sugar import makeDomain
from .operator import Operator
......@@ -43,10 +42,8 @@ class Adder(Operator):
self._domain = self._target = dom
self._neg = bool(neg)
def apply(self, x, difforder):
def apply(self, x):
self._check_input(x)
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
if self._neg:
return x - self._a
return x + self._a
......@@ -27,7 +27,6 @@ from ..sugar import makeDomain, makeOp
from .linear_operator import LinearOperator
from .operator import Operator
from .sampling_enabler import SamplingEnabler
from .sandwich_operator import SandwichOperator
from .scaling_operator import ScalingOperator
from .simple_linear_operators import VdotOperator
......@@ -58,13 +57,13 @@ class Squared2NormOperator(EnergyOperator):
def __init__(self, domain):
self._domain = domain
def apply(self, x, difforder):
def apply(self, x):
self._check_input(x)
res = x.vdot(x)
if difforder == self.VALUE_ONLY:
if not isinstance(x, Linearization):
res = x.vdot(x)
return res
jac = VdotOperator(2*x)
return Linearization(res, jac, want_metric=difforder == self.WITH_METRIC)
res = x.val.vdot(x.val)
return x.new(res, VdotOperator(2*x.val))
class QuadraticFormOperator(EnergyOperator):
......@@ -87,13 +86,12 @@ class QuadraticFormOperator(EnergyOperator):
self._op = endo
self._domain = endo.domain
def apply(self, x, difforder):
def apply(self, x):
self._check_input(x)
t1 = self._op(x)
res = 0.5*x.vdot(t1)
if difforder == self.VALUE_ONLY:
return res
return Linearization(res, VdotOperator(t1))
if not isinstance(x, Linearization):
return 0.5*x.vdot(self._op(x))
res = 0.5*x.val.vdot(self._op(x.val))
return x.new(res, VdotOperator(self._op(x.val)))
class VariableCovarianceGaussianEnergy(EnergyOperator):
......@@ -127,12 +125,10 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
dom = DomainTuple.make(domain)
self._domain = MultiDomain.make({self._r: dom, self._icov: dom})
def apply(self, x, difforder):
def apply(self, x):
self._check_input(x)
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
res = 0.5*(x[self._r].vdot(x[self._r]*x[self._icov]).real - x[self._icov].log().sum())
if difforder <= self.WITH_JAC:
if not isinstance(x, Linearization) or not x.want_metric:
return res
mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)}
return res.add_metric(makeOp(MultiField.from_dict(mf)))
......@@ -195,15 +191,13 @@ class GaussianEnergy(EnergyOperator):
if self._domain != newdom:
raise ValueError("domain mismatch")
def apply(self, x, difforder):
def apply(self, x):
self._check_input(x)
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
residual = x if self._mean is None else x - self._mean
res = self._op(residual).real
if difforder < self.WITH_METRIC:
return res
return res.add_metric(self._met)
if isinstance(x, Linearization) and x.want_metric:
return res.add_metric(self._met)
return res
class PoissonianEnergy(EnergyOperator):
......@@ -233,12 +227,10 @@ class PoissonianEnergy(EnergyOperator):
self._d = d
self._domain = DomainTuple.make(d.domain)
def apply(self, x, difforder):
def apply(self, x):
self._check_input(x)
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
res = x.sum() - x.log().vdot(self._d)
if difforder <= self.WITH_JAC:
if not isinstance(x, Linearization) or not x.want_metric:
return res
return res.add_metric(makeOp(1./x.val))
......@@ -275,12 +267,10 @@ class InverseGammaLikelihood(EnergyOperator):
raise TypeError
self._alphap1 = alpha+1
def apply(self, x, difforder):
def apply(self, x):
self._check_input(x)
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
res = x.log().vdot(self._alphap1) + x.one_over().vdot(self._beta)
if difforder <= self.WITH_JAC:
if not isinstance(x, Linearization) or not x.want_metric:
return res
return res.add_metric(makeOp(self._alphap1/(x.val**2)))
......@@ -306,12 +296,10 @@ class StudentTEnergy(EnergyOperator):
self._domain = DomainTuple.make(domain)
self._theta = theta
def apply(self, x, difforder):
def apply(self, x):
self._check_input(x)
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
res = ((self._theta+1)/2)*(x**2/self._theta).log1p().sum()
if difforder <= self.WITH_JAC:
if not isinstance(x, Linearization) or not x.want_metric:
return res
met = ScalingOperator(self.domain, (self._theta+1) / (self._theta+3))
return res.add_metric(met)
......@@ -342,16 +330,12 @@ class BernoulliEnergy(EnergyOperator):
self._d = d
self._domain = DomainTuple.make(d.domain)
def apply(self, x, difforder):
def apply(self, x):
self._check_input(x)
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
res = -x.log().vdot(self._d) + (1.-x).log().vdot(self._d-1.)
if difforder <= self.WITH_JAC:
if not isinstance(x, Linearization) or not x.want_metric:
return res
met = makeOp(1./(x.val*(1. - x.val)))
met = SandwichOperator.make(x.jac, met)
return res.add_metric(met)
return res.add_metric(makeOp(1./(x.val*(1. - x.val))))
class StandardHamiltonian(EnergyOperator):
......@@ -396,11 +380,9 @@ class StandardHamiltonian(EnergyOperator):
self._ic_samp = ic_samp
self._domain = lh.domain
def apply(self, x, difforder):
def apply(self, x):
self._check_input(x)
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
if difforder <= self.WITH_JAC or self._ic_samp is None:
if not isinstance(x, Linearization) or not x.want_metric or self._ic_samp is None:
return (self._lh + self._prior)(x)
lhx, prx = self._lh(x), self._prior(x)
return (lhx+prx).add_metric(SamplingEnabler(lhx.metric, prx.metric, self._ic_samp))
......@@ -442,9 +424,7 @@ class AveragedEnergy(EnergyOperator):
self._domain = h.domain
self._res_samples = tuple(res_samples)
def apply(self, x, difforder):
def apply(self, x):
self._check_input(x)
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
mymap = map(lambda v: self._h(x+v), self._res_samples)
return utilities.my_sum(mymap)/len(self._res_samples)
......@@ -166,7 +166,7 @@ class Operator(metaclass=NiftyMeta):
return self
return _OpChain.make((_Clipper(self.target, min, max), self))
def apply(self, x, difforder):
def apply(self, x):
"""Applies the operator to a Field or MultiField.
Parameters
......@@ -183,7 +183,8 @@ class Operator(metaclass=NiftyMeta):
return self.apply(x.extract(self.domain), 0)
def _check_input(self, x):
if not isinstance(x, (Field, MultiField)):
from ..linearization import Linearization
if not isinstance(x, (Field, MultiField, Linearization)):
raise TypeError
self._check_domain_equality(self._domain, x.domain)
......@@ -192,10 +193,9 @@ class Operator(metaclass=NiftyMeta):
from ..field import Field
from ..multi_field import MultiField
if isinstance(x, Linearization):
difforder = self.WITH_METRIC if x.want_metric else self.WITH_JAC
return self.apply(x.val, difforder).prepend_jac(x.jac)
return self.apply(x.trivial_jac()).prepend_jac(x.jac)
elif isinstance(x, (Field, MultiField)):
return self.apply(x, self.VALUE_ONLY)
return self.apply(x)
raise TypeError('Operator can only consume Field, MultiFields and Linearizations')
def ducktape(self, name):
......@@ -279,12 +279,12 @@ class _ConstantOperator(Operator):
self._target = output.domain
self._output = output
def apply(self, x, difforder):
def apply(self, x):
from ..linearization import Linearization
from .simple_linear_operators import NullOperator
self._check_input(x)
if difforder >= self.WITH_JAC:
return Linearization(self._output, NullOperator(self._domain, self._target))
if isinstance(x, Linearization):
return x.new(self._output, NullOperator(self._domain, self._target))
return self._output
def __repr__(self):
......@@ -297,11 +297,8 @@ class _FunctionApplier(Operator):
self._domain = self._target = makeDomain(domain)
self._funcname = funcname
def apply(self, x, difforder):
def apply(self, x):
self._check_input(x)
from ..linearization import Linearization
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
return getattr(x, self._funcname)()
......@@ -312,11 +309,8 @@ class _Clipper(Operator):
self._min = min
self._max = max
def apply(self, x, difforder):
def apply(self, x):
self._check_input(x)
from ..linearization import Linearization
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
return x.clip(self._min, self._max)
......@@ -326,11 +320,8 @@ class _PowerOp(Operator):
self._domain = self._target = makeDomain(domain)
self._power = power
def apply(self, x, difforder):
def apply(self, x):
self._check_input(x)
from ..linearization import Linearization
if difforder >= self.WITH_JAC:
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
return x**self._power
......@@ -366,11 +357,8 @@ class _OpChain(_CombinedOperator):
if self._ops[i-1].domain != self._ops[i].target:
raise ValueError("domain mismatch")
def apply(self, x, difforder):
def apply(self, x):
self._check_input(x)
if difforder >= self.WITH_JAC:
from ..linearization import Linearization
x = Linearization.make_var(x, difforder == self.WITH_METRIC)
for op in reversed(self._ops):
x = op(x)
return x
......@@ -401,15 +389,17 @@ class _OpProd(Operator):
self._op1 = op1
self._op2 = op2
def apply(self, x, difforder):
def apply(self, x):
from ..linearization import Linearization
from ..sugar import makeOp
self._check_input(x)
lin = isinstance(x, Linearization)
wm = x.want_metric if lin else None
x = x.val if lin else x
v1 = x.extract(self._op1.domain)
v2 = x.extract(self._op2.domain)
if difforder == self.VALUE_ONLY:
if not lin:
return self._op1(v1) * self._op2(v2)
wm = difforder == self.WITH_METRIC
lin1 = self._op1(Linearization.make_var(v1, wm))
lin2 = self._op2(Linearization.make_var(v2, wm))
jac = (makeOp(lin1._val)(lin2._jac))._myadd(makeOp(lin2._val)(lin1._jac), False)
......@@ -443,14 +433,16 @@ class _OpSum(Operator):
self._op1 = op1
self._op2 = op2
def apply(self, x, difforder):
def apply(self, x):
from ..linearization import Linearization
self._check_input(x)
v1 = x.extract(self._op1.domain)
v2 = x.extract(self._op2.domain)
if difforder == self.VALUE_ONLY:
if not isinstance(x, Linearization):
v1 = x.extract(self._op1.domain)
v2 = x.extract(self._op2.domain)
return self._op1(v1).unite(self._op2(v2))
wm = difforder == self.WITH_METRIC
v1 = x.val.extract(self._op1.domain)
v2 = x.val.extract(self._op2.domain)
wm = x.want_metric
lin1 = self._op1(Linearization.make_var(v1, wm))
lin2 = self._op2(Linearization.make_var(v2, wm))
op = lin1._jac._myadd(lin2._jac, False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment