Commit 092bf7fd authored by Philipp Arras's avatar Philipp Arras
Browse files

Implement proper constant support 6/n

parent 9dea1d88
...@@ -344,7 +344,8 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable, ...@@ -344,7 +344,8 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
assert oplin.jac.target is oplin0.jac.target assert oplin.jac.target is oplin0.jac.target
rndinp = from_random(oplin.jac.target) rndinp = from_random(oplin.jac.target)
assert_equal(oplin.jac.adjoint(rndinp).extract(varloc.domain), oplin0.jac.adjoint(rndinp)) assert_allclose(oplin.jac.adjoint(rndinp).extract(varloc.domain),
oplin0.jac.adjoint(rndinp), 1e-13, 1e-13)
foo = oplin.jac.adjoint(rndinp).extract(cstloc.domain) foo = oplin.jac.adjoint(rndinp).extract(cstloc.domain)
assert_equal(foo, 0*foo) assert_equal(foo, 0*foo)
...@@ -352,7 +353,8 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable, ...@@ -352,7 +353,8 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
oplin.metric.draw_sample() oplin.metric.draw_sample()
assert op0.domain is varloc.domain assert op0.domain is varloc.domain
_jac_vs_finite_differences(op0, varloc, np.sqrt(tol), ntries, only_r_differentiable) _jac_vs_finite_differences(op0, varloc, np.sqrt(tol), ntries,
only_r_differentiable)
def _jac_vs_finite_differences(op, loc, tol, ntries, only_r_differentiable): def _jac_vs_finite_differences(op, loc, tol, ntries, only_r_differentiable):
......
...@@ -138,13 +138,12 @@ class ChainOperator(LinearOperator): ...@@ -138,13 +138,12 @@ class ChainOperator(LinearOperator):
subs = "\n".join(sub.__repr__() for sub in self._ops) subs = "\n".join(sub.__repr__() for sub in self._ops)
return "ChainOperator:\n" + utilities.indent(subs) return "ChainOperator:\n" + utilities.indent(subs)
# def _simplify_for_constant_input_nontrivial(self, c_inp): def _simplify_for_constant_input_nontrivial(self, c_inp):
# from ..multi_domain import MultiDomain from ..multi_domain import MultiDomain
# if not isinstance(self._domain, MultiDomain): if not isinstance(self._domain, MultiDomain):
# return None, self return None, self
newop = None
# newop = None for op in reversed(self._ops):
# for op in reversed(self._ops): c_inp, t_op = op.simplify_for_constant_input(c_inp)
# c_inp, t_op = op.simplify_for_constant_input(c_inp) newop = t_op if newop is None else op(newop)
# newop = t_op if newop is None else op(newop) return c_inp, newop
# return c_inp, newop
...@@ -175,26 +175,25 @@ class VariableCovarianceGaussianEnergy(EnergyOperator): ...@@ -175,26 +175,25 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
met = MultiField.from_dict({self._kr: i.val, self._ki: met**(-2)}) met = MultiField.from_dict({self._kr: i.val, self._ki: met**(-2)})
return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt)) return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt))
# def _simplify_for_constant_input_nontrivial(self, c_inp): def _simplify_for_constant_input_nontrivial(self, c_inp):
# from .simplify_for_const import ConstantEnergyOperator from .simplify_for_const import ConstantEnergyOperator
# assert len(c_inp.keys()) == 1 assert len(c_inp.keys()) == 1
# key = c_inp.keys()[0] key = c_inp.keys()[0]
# assert key in self._domain.keys() assert key in self._domain.keys()
# cst = c_inp[key] cst = c_inp[key]
# if key == self._kr: if key == self._kr:
# res = _SpecialGammaEnergy(cst).ducktape(self._ki) res = _SpecialGammaEnergy(cst).ducktape(self._ki)
# else: else:
# dt = self._dt[self._kr] dt = self._dt[self._kr]
# res = GaussianEnergy(inverse_covariance=makeOp(cst), res = GaussianEnergy(inverse_covariance=makeOp(cst),
# sampling_dtype=dt).ducktape(self._kr) sampling_dtype=dt).ducktape(self._kr)
# trlog = cst.log().sum().val_rw() trlog = cst.log().sum().val_rw()
# if not _iscomplex(dt): if not _iscomplex(dt):
# trlog /= 2 trlog /= 2
# res = res + ConstantEnergyOperator(res.domain, -trlog) res = res + ConstantEnergyOperator(-trlog)
# res = res + ConstantEnergyOperator(self._domain, 0.) res = res + ConstantEnergyOperator(0.)
# assert res.domain is self.domain assert res.target is self.target
# assert res.target is self.target return None, res
# return None, res
class _SpecialGammaEnergy(EnergyOperator): class _SpecialGammaEnergy(EnergyOperator):
......
...@@ -371,16 +371,15 @@ class _OpChain(_CombinedOperator): ...@@ -371,16 +371,15 @@ class _OpChain(_CombinedOperator):
x = op(x) x = op(x)
return x return x
# def _simplify_for_constant_input_nontrivial(self, c_inp): def _simplify_for_constant_input_nontrivial(self, c_inp):
# from ..multi_domain import MultiDomain from ..multi_domain import MultiDomain
# if not isinstance(self._domain, MultiDomain): if not isinstance(self._domain, MultiDomain):
# return None, self return None, self
newop = None
# newop = None for op in reversed(self._ops):
# for op in reversed(self._ops): c_inp, t_op = op.simplify_for_constant_input(c_inp)
# c_inp, t_op = op.simplify_for_constant_input(c_inp) newop = t_op if newop is None else op(newop)
# newop = t_op if newop is None else op(newop) return c_inp, newop
# return c_inp, newop
def __repr__(self): def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in self._ops) subs = "\n".join(sub.__repr__() for sub in self._ops)
...@@ -413,20 +412,19 @@ class _OpProd(Operator): ...@@ -413,20 +412,19 @@ class _OpProd(Operator):
jac = (makeOp(lin1._val)(lin2._jac))._myadd(makeOp(lin2._val)(lin1._jac), False) jac = (makeOp(lin1._val)(lin2._jac))._myadd(makeOp(lin2._val)(lin1._jac), False)
return lin1.new(lin1._val*lin2._val, jac) return lin1.new(lin1._val*lin2._val, jac)
# def _simplify_for_constant_input_nontrivial(self, c_inp): def _simplify_for_constant_input_nontrivial(self, c_inp):
# from ..multi_domain import MultiDomain from ..multi_domain import MultiDomain
# from .simplify_for_const import ConstCollector from .simplify_for_const import ConstCollector
f1, o1 = self._op1.simplify_for_constant_input(
# f1, o1 = self._op1.simplify_for_constant_input( c_inp.extract_part(self._op1.domain))
# c_inp.extract_part(self._op1.domain)) f2, o2 = self._op2.simplify_for_constant_input(
# f2, o2 = self._op2.simplify_for_constant_input( c_inp.extract_part(self._op2.domain))
# c_inp.extract_part(self._op2.domain)) if not isinstance(self._target, MultiDomain):
# if not isinstance(self._target, MultiDomain): return None, _OpProd(o1, o2)
# return None, _OpProd(o1, o2) cc = ConstCollector()
# cc = ConstCollector() cc.mult(f1, o1.target)
# cc.mult(f1, o1.target) cc.mult(f2, o2.target)
# cc.mult(f2, o2.target) return cc.constfield, _OpProd(o1, o2)
# return cc.constfield, _OpProd(o1, o2)
def __repr__(self): def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in (self._op1, self._op2)) subs = "\n".join(sub.__repr__() for sub in (self._op1, self._op2))
...@@ -459,20 +457,19 @@ class _OpSum(Operator): ...@@ -459,20 +457,19 @@ class _OpSum(Operator):
res = res.add_metric(lin1._metric._myadd(lin2._metric, False)) res = res.add_metric(lin1._metric._myadd(lin2._metric, False))
return res return res
# def _simplify_for_constant_input_nontrivial(self, c_inp): def _simplify_for_constant_input_nontrivial(self, c_inp):
# from ..multi_domain import MultiDomain from ..multi_domain import MultiDomain
# from .simplify_for_const import ConstCollector from .simplify_for_const import ConstCollector
f1, o1 = self._op1.simplify_for_constant_input(
# f1, o1 = self._op1.simplify_for_constant_input( c_inp.extract_part(self._op1.domain))
# c_inp.extract_part(self._op1.domain)) f2, o2 = self._op2.simplify_for_constant_input(
# f2, o2 = self._op2.simplify_for_constant_input( c_inp.extract_part(self._op2.domain))
# c_inp.extract_part(self._op2.domain)) if not isinstance(self._target, MultiDomain):
# if not isinstance(self._target, MultiDomain): return None, _OpSum(o1, o2)
# return None, _OpSum(o1, o2) cc = ConstCollector()
# cc = ConstCollector() cc.add(f1, o1.target)
# cc.add(f1, o1.target) cc.add(f2, o2.target)
# cc.add(f2, o2.target) return cc.constfield, _OpSum(o1, o2)
# return cc.constfield, _OpSum(o1, o2)
def __repr__(self): def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in (self._op1, self._op2)) subs = "\n".join(sub.__repr__() for sub in (self._op1, self._op2))
......
...@@ -207,28 +207,28 @@ class SumOperator(LinearOperator): ...@@ -207,28 +207,28 @@ class SumOperator(LinearOperator):
subs = "\n".join(sub.__repr__() for sub in self._ops) subs = "\n".join(sub.__repr__() for sub in self._ops)
return "SumOperator:\n"+indent(subs) return "SumOperator:\n"+indent(subs)
# def _simplify_for_constant_input_nontrivial(self, c_inp): def _simplify_for_constant_input_nontrivial(self, c_inp):
# f = [] f = []
# o = [] o = []
# for op in self._ops: for op in self._ops:
# tf, to = op.simplify_for_constant_input( tf, to = op.simplify_for_constant_input(
# c_inp.extract_part(op.domain)) c_inp.extract_part(op.domain))
# f.append(tf) f.append(tf)
# o.append(to) o.append(to)
# from ..multi_domain import MultiDomain from ..multi_domain import MultiDomain
# if not isinstance(self._target, MultiDomain): if not isinstance(self._target, MultiDomain):
# fullop = None fullop = None
# for to, n in zip(o, self._neg): for to, n in zip(o, self._neg):
# op = to if not n else -to op = to if not n else -to
# fullop = op if fullop is None else fullop + op fullop = op if fullop is None else fullop + op
# return None, fullop return None, fullop
# from .simplify_for_const import ConstCollector from .simplify_for_const import ConstCollector
# cc = ConstCollector() cc = ConstCollector()
# fullop = None fullop = None
# for tf, to, n in zip(f, o, self._neg): for tf, to, n in zip(f, o, self._neg):
# cc.add(tf, to.target) cc.add(tf, to.target)
# op = to if not n else -to op = to if not n else -to
# fullop = op if fullop is None else fullop + op fullop = op if fullop is None else fullop + op
# return cc.constfield, fullop return cc.constfield, fullop
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment