diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py index 375995f63104e6d70e68fa7388be48098c0c7472..b836041aeeef5d3ea9637da9206f02b1038b55f7 100644 --- a/nifty5/operators/operator.py +++ b/nifty5/operators/operator.py @@ -179,6 +179,35 @@ class _ConstantOperator(Operator): def __repr__(self): return 'ConstantOperator <- {}'.format(self.domain.keys()) + + +class _ConstantOperator2(Operator): + def __init__(self, target, constant_output): + from ..sugar import makeDomain + self._target = makeDomain(target) + dom_keys = set(target.keys())-set(constant_output.domain.keys()) + self._domain = makeDomain({key: self._target[key] for key in dom_keys}) + self._constant_output = constant_output + + def apply(self, x): + from ..linearization import Linearization + self._check_input(x) + if not isinstance(x, Linearization): + return x.unite(self._constant_output) + from .simple_linear_operators import _PartialExtractor + + op = _PartialExtractor(self.target, x.jac.target).adjoint + val = x.val.unite(self._constant_output) + + assert val.domain is self.target + assert val.domain is op.target + + return x.new(val, op(x.jac)) + + def __repr__(self): + return 'ConstantOperator2: {} <- {}'.format(self.target.keys(), self.domain.keys()) + + class _FunctionApplier(Operator): def __init__(self, domain, funcname): from ..sugar import makeDomain