Skip to content
Snippets Groups Projects
Commit c1b69ae1 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add ConstantOperator2

parent 17e3b653
No related branches found
No related tags found
1 merge request!295Simplify for const
......@@ -179,6 +179,35 @@ class _ConstantOperator(Operator):
def __repr__(self):
return 'ConstantOperator <- {}'.format(self.domain.keys())
class _ConstantOperator2(Operator):
def __init__(self, target, constant_output):
from ..sugar import makeDomain
self._target = makeDomain(target)
dom_keys = set(target.keys())-set(constant_output.domain.keys())
self._domain = makeDomain({key: self._target[key] for key in dom_keys})
self._constant_output = constant_output
def apply(self, x):
from ..linearization import Linearization
self._check_input(x)
if not isinstance(x, Linearization):
return x.unite(self._constant_output)
from .simple_linear_operators import _PartialExtractor
op = _PartialExtractor(self.target, x.jac.target).adjoint
val = x.val.unite(self._constant_output)
assert val.domain is self.target
assert val.domain is op.target
return x.new(val, op(x.jac))
def __repr__(self):
return 'ConstantOperator2: {} <- {}'.format(self.target.keys(), self.domain.keys())
class _FunctionApplier(Operator):
def __init__(self, domain, funcname):
from ..sugar import makeDomain
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment