Commit 997cdaef authored by Martin Reinecke
Browse files

make metric calculation optional

parent f30e88f9
...@@ -41,7 +41,7 @@ def _get_acceptable_location(op, loc, lin): ...@@ -41,7 +41,7 @@ def _get_acceptable_location(op, loc, lin):
for i in range(50): for i in range(50):
try: try:
loc2 = loc+dir loc2 = loc+dir
lin2 = op(Linearization.make_var(loc2)) lin2 = op(Linearization.make_var(loc2, lin.want_metric))
if np.isfinite(lin2.val.sum()) and abs(lin2.val.sum()) < 1e20: if np.isfinite(lin2.val.sum()) and abs(lin2.val.sum()) < 1e20:
break break
except FloatingPointError: except FloatingPointError:
...@@ -54,14 +54,14 @@ def _get_acceptable_location(op, loc, lin): ...@@ -54,14 +54,14 @@ def _get_acceptable_location(op, loc, lin):
def _check_consistency(op, loc, tol, ntries, do_metric): def _check_consistency(op, loc, tol, ntries, do_metric):
for _ in range(ntries): for _ in range(ntries):
lin = op(Linearization.make_var(loc)) lin = op(Linearization.make_var(loc, do_metric))
loc2, lin2 = _get_acceptable_location(op, loc, lin) loc2, lin2 = _get_acceptable_location(op, loc, lin)
dir = loc2-loc dir = loc2-loc
locnext = loc2 locnext = loc2
dirnorm = dir.norm() dirnorm = dir.norm()
for i in range(50): for i in range(50):
locmid = loc + 0.5*dir locmid = loc + 0.5*dir
linmid = op(Linearization.make_var(locmid)) linmid = op(Linearization.make_var(locmid, do_metric))
dirder = linmid.jac(dir) dirder = linmid.jac(dir)
numgrad = (lin2.val-lin.val) numgrad = (lin2.val-lin.val)
xtol = tol * dirder.norm() / np.sqrt(dirder.size) xtol = tol * dirder.norm() / np.sqrt(dirder.size)
......
...@@ -53,7 +53,7 @@ class InverseGammaModel(Operator): ...@@ -53,7 +53,7 @@ class InverseGammaModel(Operator):
outer = 1/outer_inv outer = 1/outer_inv
jac = makeOp(Field.from_local_data(self._domain, inner*outer)) jac = makeOp(Field.from_local_data(self._domain, inner*outer))
jac = jac(x.jac) jac = jac(x.jac)
return Linearization(points, jac) return x.new(points, jac)
@staticmethod @staticmethod
def IG(field, alpha, q): def IG(field, alpha, q):
......
...@@ -9,13 +9,17 @@ from .sugar import makeOp ...@@ -9,13 +9,17 @@ from .sugar import makeOp
class Linearization(object): class Linearization(object):
def __init__(self, val, jac, metric=None): def __init__(self, val, jac, metric=None, want_metric=False):
self._val = val self._val = val
self._jac = jac self._jac = jac
if self._val.domain != self._jac.target: if self._val.domain != self._jac.target:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
self._want_metric = want_metric
self._metric = metric self._metric = metric
def new(self, val, jac, metric=None):
return Linearization(val, jac, metric, self._want_metric)
@property @property
def domain(self): def domain(self):
return self._jac.domain return self._jac.domain
...@@ -37,6 +41,10 @@ class Linearization(object): ...@@ -37,6 +41,10 @@ class Linearization(object):
"""Only available if target is a scalar""" """Only available if target is a scalar"""
return self._jac.adjoint_times(Field.scalar(1.)) return self._jac.adjoint_times(Field.scalar(1.))
@property
def want_metric(self):
return self._want_metric
@property @property
def metric(self): def metric(self):
"""Only available if target is a scalar""" """Only available if target is a scalar"""
...@@ -44,35 +52,34 @@ class Linearization(object): ...@@ -44,35 +52,34 @@ class Linearization(object):
def __getitem__(self, name): def __getitem__(self, name):
from .operators.simple_linear_operators import FieldAdapter from .operators.simple_linear_operators import FieldAdapter
return Linearization(self._val[name], FieldAdapter(self.domain, name)) return self.new(self._val[name], FieldAdapter(self.domain, name))
def __neg__(self): def __neg__(self):
return Linearization( return self.new(-self._val, -self._jac,
-self._val, -self._jac,
None if self._metric is None else -self._metric) None if self._metric is None else -self._metric)
def conjugate(self): def conjugate(self):
return Linearization( return self.new(
self._val.conjugate(), self._jac.conjugate(), self._val.conjugate(), self._jac.conjugate(),
None if self._metric is None else self._metric.conjugate()) None if self._metric is None else self._metric.conjugate())
@property @property
def real(self): def real(self):
return Linearization(self._val.real, self._jac.real) return self.new(self._val.real, self._jac.real)
def _myadd(self, other, neg): def _myadd(self, other, neg):
if isinstance(other, Linearization): if isinstance(other, Linearization):
met = None met = None
if self._metric is not None and other._metric is not None: if self._metric is not None and other._metric is not None:
met = self._metric._myadd(other._metric, neg) met = self._metric._myadd(other._metric, neg)
return Linearization( return self.new(
self._val.flexible_addsub(other._val, neg), self._val.flexible_addsub(other._val, neg),
self._jac._myadd(other._jac, neg), met) self._jac._myadd(other._jac, neg), met)
if isinstance(other, (int, float, complex, Field, MultiField)): if isinstance(other, (int, float, complex, Field, MultiField)):
if neg: if neg:
return Linearization(self._val-other, self._jac, self._metric) return self.new(self._val-other, self._jac, self._metric)
else: else:
return Linearization(self._val+other, self._jac, self._metric) return self.new(self._val+other, self._jac, self._metric)
def __add__(self, other): def __add__(self, other):
return self._myadd(other, False) return self._myadd(other, False)
...@@ -91,7 +98,7 @@ class Linearization(object): ...@@ -91,7 +98,7 @@ class Linearization(object):
if isinstance(other, Linearization): if isinstance(other, Linearization):
if self.target != other.target: if self.target != other.target:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
return Linearization( return self.new(
self._val*other._val, self._val*other._val,
(makeOp(other._val)(self._jac))._myadd( (makeOp(other._val)(self._jac))._myadd(
makeOp(self._val)(other._jac), False)) makeOp(self._val)(other._jac), False))
...@@ -99,11 +106,11 @@ class Linearization(object): ...@@ -99,11 +106,11 @@ class Linearization(object):
if other == 1: if other == 1:
return self return self
met = None if self._metric is None else self._metric.scale(other) met = None if self._metric is None else self._metric.scale(other)
return Linearization(self._val*other, self._jac.scale(other), met) return self.new(self._val*other, self._jac.scale(other), met)
if isinstance(other, (Field, MultiField)): if isinstance(other, (Field, MultiField)):
if self.target != other.domain: if self.target != other.domain:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
return Linearization(self._val*other, makeOp(other)(self._jac)) return self.new(self._val*other, makeOp(other)(self._jac))
def __rmul__(self, other): def __rmul__(self, other):
return self.__mul__(other) return self.__mul__(other)
...@@ -111,46 +118,48 @@ class Linearization(object): ...@@ -111,46 +118,48 @@ class Linearization(object):
def vdot(self, other): def vdot(self, other):
from .operators.simple_linear_operators import VdotOperator from .operators.simple_linear_operators import VdotOperator
if isinstance(other, (Field, MultiField)): if isinstance(other, (Field, MultiField)):
return Linearization( return self.new(
Field.scalar(self._val.vdot(other)), Field.scalar(self._val.vdot(other)),
VdotOperator(other)(self._jac)) VdotOperator(other)(self._jac))
return Linearization( return self.new(
Field.scalar(self._val.vdot(other._val)), Field.scalar(self._val.vdot(other._val)),
VdotOperator(self._val)(other._jac) + VdotOperator(self._val)(other._jac) +
VdotOperator(other._val)(self._jac)) VdotOperator(other._val)(self._jac))
def sum(self): def sum(self):
from .operators.simple_linear_operators import SumReductionOperator from .operators.simple_linear_operators import SumReductionOperator
return Linearization( return self.new(
Field.scalar(self._val.sum()), Field.scalar(self._val.sum()),
SumReductionOperator(self._jac.target)(self._jac)) SumReductionOperator(self._jac.target)(self._jac))
def exp(self): def exp(self):
tmp = self._val.exp() tmp = self._val.exp()
return Linearization(tmp, makeOp(tmp)(self._jac)) return self.new(tmp, makeOp(tmp)(self._jac))
def log(self): def log(self):
tmp = self._val.log() tmp = self._val.log()
return Linearization(tmp, makeOp(1./self._val)(self._jac)) return self.new(tmp, makeOp(1./self._val)(self._jac))
def tanh(self): def tanh(self):
tmp = self._val.tanh() tmp = self._val.tanh()
return Linearization(tmp, makeOp(1.-tmp**2)(self._jac)) return self.new(tmp, makeOp(1.-tmp**2)(self._jac))
def positive_tanh(self): def positive_tanh(self):
tmp = self._val.tanh() tmp = self._val.tanh()
tmp2 = 0.5*(1.+tmp) tmp2 = 0.5*(1.+tmp)
return Linearization(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac)) return self.new(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac))
def add_metric(self, metric): def add_metric(self, metric):
return Linearization(self._val, self._jac, metric) return self.new(self._val, self._jac, metric)
@staticmethod @staticmethod
def make_var(field): def make_var(field, want_metric=False):
from .operators.scaling_operator import ScalingOperator from .operators.scaling_operator import ScalingOperator
return Linearization(field, ScalingOperator(1., field.domain)) return Linearization(field, ScalingOperator(1., field.domain),
want_metric=want_metric)
@staticmethod @staticmethod
def make_const(field): def make_const(field, want_metric=False):
from .operators.simple_linear_operators import NullOperator from .operators.simple_linear_operators import NullOperator
return Linearization(field, NullOperator(field.domain, field.domain)) return Linearization(field, NullOperator(field.domain, field.domain),
want_metric=want_metric)
...@@ -8,17 +8,18 @@ from ..operators.scaling_operator import ScalingOperator ...@@ -8,17 +8,18 @@ from ..operators.scaling_operator import ScalingOperator
class EnergyAdapter(Energy): class EnergyAdapter(Energy):
def __init__(self, position, op, constants=[]): def __init__(self, position, op, constants=[], want_metric=False):
super(EnergyAdapter, self).__init__(position) super(EnergyAdapter, self).__init__(position)
self._op = op self._op = op
self._constants = constants self._constants = constants
if len(self._constants) == 0: if len(self._constants) == 0:
tmp = self._op(Linearization.make_var(self._position)) tmp = self._op(Linearization.make_var(self._position, want_metric))
else: else:
ops = [ScalingOperator(0. if key in self._constants else 1., dom) ops = [ScalingOperator(0. if key in self._constants else 1., dom)
for key, dom in self._position.domain.items()] for key, dom in self._position.domain.items()]
bdop = BlockDiagonalOperator(self._position.domain, tuple(ops)) bdop = BlockDiagonalOperator(self._position.domain, tuple(ops))
tmp = self._op(Linearization(self._position, bdop)) tmp = self._op(Linearization(self._position, bdop,
want_metric=want_metric))
self._val = tmp.val.local_data[()] self._val = tmp.val.local_data[()]
self._grad = tmp.gradient self._grad = tmp.gradient
self._metric = tmp._metric self._metric = tmp._metric
......
...@@ -9,22 +9,24 @@ from .. import utilities ...@@ -9,22 +9,24 @@ from .. import utilities
class KL_Energy(Energy): class KL_Energy(Energy):
def __init__(self, position, h, nsamp, constants=[], _samples=None): def __init__(self, position, h, nsamp, constants=[], _samples=None,
want_metric=False):
super(KL_Energy, self).__init__(position) super(KL_Energy, self).__init__(position)
self._h = h self._h = h
self._constants = constants self._constants = constants
self._want_metric = want_metric
if _samples is None: if _samples is None:
met = h(Linearization.make_var(position)).metric met = h(Linearization.make_var(position, True)).metric
_samples = tuple(met.draw_sample(from_inverse=True) _samples = tuple(met.draw_sample(from_inverse=True)
for _ in range(nsamp)) for _ in range(nsamp))
self._samples = _samples self._samples = _samples
if len(constants) == 0: if len(constants) == 0:
tmp = Linearization.make_var(position) tmp = Linearization.make_var(position, want_metric)
else: else:
ops = [ScalingOperator(0. if key in constants else 1., dom) ops = [ScalingOperator(0. if key in constants else 1., dom)
for key, dom in position.domain.items()] for key, dom in position.domain.items()]
bdop = BlockDiagonalOperator(position.domain, tuple(ops)) bdop = BlockDiagonalOperator(position.domain, tuple(ops))
tmp = Linearization(position, bdop) tmp = Linearization(position, bdop, want_metric=want_metric)
mymap = map(lambda v: self._h(tmp+v), self._samples) mymap = map(lambda v: self._h(tmp+v), self._samples)
tmp = utilities.my_sum(mymap) * (1./len(self._samples)) tmp = utilities.my_sum(mymap) * (1./len(self._samples))
self._val = tmp.val.local_data[()] self._val = tmp.val.local_data[()]
...@@ -32,7 +34,8 @@ class KL_Energy(Energy): ...@@ -32,7 +34,8 @@ class KL_Energy(Energy):
self._metric = tmp.metric self._metric = tmp.metric
def at(self, position): def at(self, position):
return KL_Energy(position, self._h, 0, self._constants, self._samples) return KL_Energy(position, self._h, 0, self._constants, self._samples,
self._want_metric)
@property @property
def value(self): def value(self):
......
...@@ -42,7 +42,7 @@ class SquaredNormOperator(EnergyOperator): ...@@ -42,7 +42,7 @@ class SquaredNormOperator(EnergyOperator):
if isinstance(x, Linearization): if isinstance(x, Linearization):
val = Field.scalar(x.val.vdot(x.val)) val = Field.scalar(x.val.vdot(x.val))
jac = VdotOperator(2*x.val)(x.jac) jac = VdotOperator(2*x.val)(x.jac)
return Linearization(val, jac) return x.new(val, jac)
return Field.scalar(x.vdot(x)) return Field.scalar(x.vdot(x))
...@@ -59,7 +59,7 @@ class QuadraticFormOperator(EnergyOperator): ...@@ -59,7 +59,7 @@ class QuadraticFormOperator(EnergyOperator):
t1 = self._op(x.val) t1 = self._op(x.val)
jac = VdotOperator(t1)(x.jac) jac = VdotOperator(t1)(x.jac)
val = Field.scalar(0.5*x.val.vdot(t1)) val = Field.scalar(0.5*x.val.vdot(t1))
return Linearization(val, jac) return x.new(val, jac)
return Field.scalar(0.5*x.vdot(self._op(x))) return Field.scalar(0.5*x.vdot(self._op(x)))
...@@ -91,7 +91,7 @@ class GaussianEnergy(EnergyOperator): ...@@ -91,7 +91,7 @@ class GaussianEnergy(EnergyOperator):
def apply(self, x): def apply(self, x):
residual = x if self._mean is None else x-self._mean residual = x if self._mean is None else x-self._mean
res = self._op(residual).real res = self._op(residual).real
if not isinstance(x, Linearization): if not isinstance(x, Linearization) or not x.want_metric:
return res return res
metric = SandwichOperator.make(x.jac, self._icov) metric = SandwichOperator.make(x.jac, self._icov)
return res.add_metric(metric) return res.add_metric(metric)
...@@ -107,6 +107,8 @@ class PoissonianEnergy(EnergyOperator): ...@@ -107,6 +107,8 @@ class PoissonianEnergy(EnergyOperator):
res = x.sum() - x.log().vdot(self._d) res = x.sum() - x.log().vdot(self._d)
if not isinstance(x, Linearization): if not isinstance(x, Linearization):
return Field.scalar(res) return Field.scalar(res)
if not x.want_metric:
return res
metric = SandwichOperator.make(x.jac, makeOp(1./x.val)) metric = SandwichOperator.make(x.jac, makeOp(1./x.val))
return res.add_metric(metric) return res.add_metric(metric)
...@@ -122,6 +124,8 @@ class BernoulliEnergy(EnergyOperator): ...@@ -122,6 +124,8 @@ class BernoulliEnergy(EnergyOperator):
v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d) v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d)
if not isinstance(x, Linearization): if not isinstance(x, Linearization):
return Field.scalar(v) return Field.scalar(v)
if not x.want_metric:
return v
met = makeOp(1./(x.val*(1.-x.val))) met = makeOp(1./(x.val*(1.-x.val)))
met = SandwichOperator.make(x.jac, met) met = SandwichOperator.make(x.jac, met)
return v.add_metric(met) return v.add_metric(met)
...@@ -135,11 +139,11 @@ class Hamiltonian(EnergyOperator): ...@@ -135,11 +139,11 @@ class Hamiltonian(EnergyOperator):
self._domain = lh.domain self._domain = lh.domain
def apply(self, x): def apply(self, x):
if self._ic_samp is None or not isinstance(x, Linearization): if (self._ic_samp is None or not isinstance(x, Linearization) or
not x.want_metric):
return self._lh(x)+self._prior(x) return self._lh(x)+self._prior(x)
else: else:
lhx = self._lh(x) lhx, prx = self._lh(x), self._prior(x)
prx = self._prior(x)
mtr = SamplingEnabler(lhx.metric, prx.metric.inverse, mtr = SamplingEnabler(lhx.metric, prx.metric.inverse,
self._ic_samp, prx.metric.inverse) self._ic_samp, prx.metric.inverse)
return (lhx+prx).add_metric(mtr) return (lhx+prx).add_metric(mtr)
......
...@@ -175,7 +175,7 @@ class LinearOperator(Operator): ...@@ -175,7 +175,7 @@ class LinearOperator(Operator):
return self.apply(x, self.TIMES) return self.apply(x, self.TIMES)
from ..linearization import Linearization from ..linearization import Linearization
if isinstance(x, Linearization): if isinstance(x, Linearization):
return Linearization(self(x._val), self(x._jac)) return x.new(self(x._val), self(x._jac))
return self.__matmul__(x) return self.__matmul__(x)
def times(self, x): def times(self, x):
......
...@@ -144,11 +144,12 @@ class _OpProd(Operator): ...@@ -144,11 +144,12 @@ class _OpProd(Operator):
v2 = v.extract(self._op2.domain) v2 = v.extract(self._op2.domain)
if not lin: if not lin:
return self._op1(v1) * self._op2(v2) return self._op1(v1) * self._op2(v2)
lin1 = self._op1(Linearization.make_var(v1)) wm = x.want_metric
lin2 = self._op2(Linearization.make_var(v2)) lin1 = self._op1(Linearization.make_var(v1, wm))
lin2 = self._op2(Linearization.make_var(v2, wm))
op = (makeOp(lin1._val)(lin2._jac))._myadd( op = (makeOp(lin1._val)(lin2._jac))._myadd(
makeOp(lin2._val)(lin1._jac), False) makeOp(lin2._val)(lin1._jac), False)
return Linearization(lin1._val*lin2._val, op(x.jac)) return lin1.new(lin1._val*lin2._val, op(x.jac))
class _OpSum(Operator): class _OpSum(Operator):
...@@ -168,10 +169,11 @@ class _OpSum(Operator): ...@@ -168,10 +169,11 @@ class _OpSum(Operator):
res = None res = None
if not lin: if not lin:
return self._op1(v1).unite(self._op2(v2)) return self._op1(v1).unite(self._op2(v2))
lin1 = self._op1(Linearization.make_var(v1)) wm = x.want_metric
lin2 = self._op2(Linearization.make_var(v2)) lin1 = self._op1(Linearization.make_var(v1, wm))
lin2 = self._op2(Linearization.make_var(v2, wm))
op = lin1._jac._myadd(lin2._jac, False) op = lin1._jac._myadd(lin2._jac, False)
res = Linearization(lin1._val+lin2._val, op(x.jac)) res = lin1.new(lin1._val+lin2._val, op(x.jac))
if lin1._metric is not None and lin2._metric is not None: if lin1._metric is not None and lin2._metric is not None:
res = res.add_metric(lin1._metric + lin2._metric) res = res.add_metric(lin1._metric + lin2._metric)
return res return res
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment