Commit 092bf7fd authored by Philipp Arras's avatar Philipp Arras
Browse files

Implement proper constant support 6/n

parent 9dea1d88
......@@ -344,7 +344,8 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
assert oplin.jac.target is oplin0.jac.target
rndinp = from_random(oplin.jac.target)
assert_equal(oplin.jac.adjoint(rndinp).extract(varloc.domain), oplin0.jac.adjoint(rndinp))
assert_allclose(oplin.jac.adjoint(rndinp).extract(varloc.domain),
oplin0.jac.adjoint(rndinp), 1e-13, 1e-13)
foo = oplin.jac.adjoint(rndinp).extract(cstloc.domain)
assert_equal(foo, 0*foo)
......@@ -352,7 +353,8 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
oplin.metric.draw_sample()
assert op0.domain is varloc.domain
_jac_vs_finite_differences(op0, varloc, np.sqrt(tol), ntries, only_r_differentiable)
_jac_vs_finite_differences(op0, varloc, np.sqrt(tol), ntries,
only_r_differentiable)
def _jac_vs_finite_differences(op, loc, tol, ntries, only_r_differentiable):
......
......@@ -138,13 +138,12 @@ class ChainOperator(LinearOperator):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "ChainOperator:\n" + utilities.indent(subs)
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# from ..multi_domain import MultiDomain
# if not isinstance(self._domain, MultiDomain):
# return None, self
# newop = None
# for op in reversed(self._ops):
# c_inp, t_op = op.simplify_for_constant_input(c_inp)
# newop = t_op if newop is None else op(newop)
# return c_inp, newop
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
if not isinstance(self._domain, MultiDomain):
return None, self
newop = None
for op in reversed(self._ops):
c_inp, t_op = op.simplify_for_constant_input(c_inp)
newop = t_op if newop is None else op(newop)
return c_inp, newop
......@@ -175,26 +175,25 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
met = MultiField.from_dict({self._kr: i.val, self._ki: met**(-2)})
return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt))
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# from .simplify_for_const import ConstantEnergyOperator
# assert len(c_inp.keys()) == 1
# key = c_inp.keys()[0]
# assert key in self._domain.keys()
# cst = c_inp[key]
# if key == self._kr:
# res = _SpecialGammaEnergy(cst).ducktape(self._ki)
# else:
# dt = self._dt[self._kr]
# res = GaussianEnergy(inverse_covariance=makeOp(cst),
# sampling_dtype=dt).ducktape(self._kr)
# trlog = cst.log().sum().val_rw()
# if not _iscomplex(dt):
# trlog /= 2
# res = res + ConstantEnergyOperator(res.domain, -trlog)
# res = res + ConstantEnergyOperator(self._domain, 0.)
# assert res.domain is self.domain
# assert res.target is self.target
# return None, res
def _simplify_for_constant_input_nontrivial(self, c_inp):
from .simplify_for_const import ConstantEnergyOperator
assert len(c_inp.keys()) == 1
key = c_inp.keys()[0]
assert key in self._domain.keys()
cst = c_inp[key]
if key == self._kr:
res = _SpecialGammaEnergy(cst).ducktape(self._ki)
else:
dt = self._dt[self._kr]
res = GaussianEnergy(inverse_covariance=makeOp(cst),
sampling_dtype=dt).ducktape(self._kr)
trlog = cst.log().sum().val_rw()
if not _iscomplex(dt):
trlog /= 2
res = res + ConstantEnergyOperator(-trlog)
res = res + ConstantEnergyOperator(0.)
assert res.target is self.target
return None, res
class _SpecialGammaEnergy(EnergyOperator):
......
......@@ -371,16 +371,15 @@ class _OpChain(_CombinedOperator):
x = op(x)
return x
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# from ..multi_domain import MultiDomain
# if not isinstance(self._domain, MultiDomain):
# return None, self
# newop = None
# for op in reversed(self._ops):
# c_inp, t_op = op.simplify_for_constant_input(c_inp)
# newop = t_op if newop is None else op(newop)
# return c_inp, newop
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
if not isinstance(self._domain, MultiDomain):
return None, self
newop = None
for op in reversed(self._ops):
c_inp, t_op = op.simplify_for_constant_input(c_inp)
newop = t_op if newop is None else op(newop)
return c_inp, newop
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in self._ops)
......@@ -413,20 +412,19 @@ class _OpProd(Operator):
jac = (makeOp(lin1._val)(lin2._jac))._myadd(makeOp(lin2._val)(lin1._jac), False)
return lin1.new(lin1._val*lin2._val, jac)
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# from ..multi_domain import MultiDomain
# from .simplify_for_const import ConstCollector
# f1, o1 = self._op1.simplify_for_constant_input(
# c_inp.extract_part(self._op1.domain))
# f2, o2 = self._op2.simplify_for_constant_input(
# c_inp.extract_part(self._op2.domain))
# if not isinstance(self._target, MultiDomain):
# return None, _OpProd(o1, o2)
# cc = ConstCollector()
# cc.mult(f1, o1.target)
# cc.mult(f2, o2.target)
# return cc.constfield, _OpProd(o1, o2)
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
from .simplify_for_const import ConstCollector
f1, o1 = self._op1.simplify_for_constant_input(
c_inp.extract_part(self._op1.domain))
f2, o2 = self._op2.simplify_for_constant_input(
c_inp.extract_part(self._op2.domain))
if not isinstance(self._target, MultiDomain):
return None, _OpProd(o1, o2)
cc = ConstCollector()
cc.mult(f1, o1.target)
cc.mult(f2, o2.target)
return cc.constfield, _OpProd(o1, o2)
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in (self._op1, self._op2))
......@@ -459,20 +457,19 @@ class _OpSum(Operator):
res = res.add_metric(lin1._metric._myadd(lin2._metric, False))
return res
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# from ..multi_domain import MultiDomain
# from .simplify_for_const import ConstCollector
# f1, o1 = self._op1.simplify_for_constant_input(
# c_inp.extract_part(self._op1.domain))
# f2, o2 = self._op2.simplify_for_constant_input(
# c_inp.extract_part(self._op2.domain))
# if not isinstance(self._target, MultiDomain):
# return None, _OpSum(o1, o2)
# cc = ConstCollector()
# cc.add(f1, o1.target)
# cc.add(f2, o2.target)
# return cc.constfield, _OpSum(o1, o2)
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
from .simplify_for_const import ConstCollector
f1, o1 = self._op1.simplify_for_constant_input(
c_inp.extract_part(self._op1.domain))
f2, o2 = self._op2.simplify_for_constant_input(
c_inp.extract_part(self._op2.domain))
if not isinstance(self._target, MultiDomain):
return None, _OpSum(o1, o2)
cc = ConstCollector()
cc.add(f1, o1.target)
cc.add(f2, o2.target)
return cc.constfield, _OpSum(o1, o2)
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in (self._op1, self._op2))
......
......@@ -207,28 +207,28 @@ class SumOperator(LinearOperator):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "SumOperator:\n"+indent(subs)
# def _simplify_for_constant_input_nontrivial(self, c_inp):
# f = []
# o = []
# for op in self._ops:
# tf, to = op.simplify_for_constant_input(
# c_inp.extract_part(op.domain))
# f.append(tf)
# o.append(to)
# from ..multi_domain import MultiDomain
# if not isinstance(self._target, MultiDomain):
# fullop = None
# for to, n in zip(o, self._neg):
# op = to if not n else -to
# fullop = op if fullop is None else fullop + op
# return None, fullop
# from .simplify_for_const import ConstCollector
# cc = ConstCollector()
# fullop = None
# for tf, to, n in zip(f, o, self._neg):
# cc.add(tf, to.target)
# op = to if not n else -to
# fullop = op if fullop is None else fullop + op
# return cc.constfield, fullop
def _simplify_for_constant_input_nontrivial(self, c_inp):
f = []
o = []
for op in self._ops:
tf, to = op.simplify_for_constant_input(
c_inp.extract_part(op.domain))
f.append(tf)
o.append(to)
from ..multi_domain import MultiDomain
if not isinstance(self._target, MultiDomain):
fullop = None
for to, n in zip(o, self._neg):
op = to if not n else -to
fullop = op if fullop is None else fullop + op
return None, fullop
from .simplify_for_const import ConstCollector
cc = ConstCollector()
fullop = None
for tf, to, n in zip(f, o, self._neg):
cc.add(tf, to.target)
op = to if not n else -to
fullop = op if fullop is None else fullop + op
return cc.constfield, fullop
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment