Commit f4c5baef authored by Martin Reinecke's avatar Martin Reinecke
Browse files

temporatu experiment

parent c8860db8
Pipeline #76492 failed with stages
in 4 minutes and 1 second
...@@ -331,6 +331,8 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable, ...@@ -331,6 +331,8 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
val0 = op(loc) val0 = op(loc)
_, op0 = op.simplify_for_constant_input(cstloc) _, op0 = op.simplify_for_constant_input(cstloc)
from .operators.simplify_for_const import ConstantReplacer
op0 = op0@ConstantReplacer(op0.domain, cstloc)
val1 = op0(loc) val1 = op0(loc)
# MR FIXME: This tests something we don't promise! # MR FIXME: This tests something we don't promise!
# val2 = op0(loc.unite(cstloc)) # val2 = op0(loc.unite(cstloc))
...@@ -79,6 +79,40 @@ class ConstantOperator(Operator): ...@@ -79,6 +79,40 @@ class ConstantOperator(Operator):
return f'{tgt} <- ConstantOperator <- {dom}' return f'{tgt} <- ConstantOperator <- {dom}'
# NOTE: this operator had domein == target (unlike ConstantOperator, where this is not necessarily the case)!
class ConstantReplacer(Operator):
def __init__(self, dom, output):
from ..sugar import makeDomain
from ..domain_tuple import DomainTuple
self._domain = self._target = makeDomain(dom)
self._output = output
if isinstance (self._domain, DomainTuple):
assert(self._domain is output.domain, "domain mismatch")
for k in output.keys():
assert (self._domain[k] is output.domain[k], "subdomain mismatch")
self._nonconst = makeDomain({key: subdom for key, subdom in self._domain.items() if key not in self._output.keys()})
def apply(self, x):
from .simple_linear_operators import NullOperator
from ..domain_tuple import DomainTuple
if isinstance (self._domain, DomainTuple):
out = self._output
if x.jac is not None:
out = x.val.extract_part(self._nonconst).unite(self._output)
out = x.extract_part(self._nonconst).unite(self._output)
if x.jac is not None:
return, NullOperator(self._domain, self._target))
return out
def __repr__(self):
dom = self.domain.keys() if isinstance(self.domain, MultiDomain) else '()'
return f'{dom} <- ConstantReplacer <- {dom}'
class SlowPartialConstantOperator(Operator): class SlowPartialConstantOperator(Operator):
def __init__(self, domain, constant_keys): def __init__(self, domain, constant_keys):
from ..sugar import makeDomain from ..sugar import makeDomain
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment