diff --git a/nifty5/multi_field.py b/nifty5/multi_field.py index 52bf3508dc8187f11311ff17ab68e72db59da92e..1acaabd8b5ecd899d175737b8fd4254c093043c2 100644 --- a/nifty5/multi_field.py +++ b/nifty5/multi_field.py @@ -219,10 +219,6 @@ class MultiField(object): return MultiField.from_dict({key: self[key] for key in subset.keys() if key in self}) - def extract_by_keys(self, keys): - keys = set(self.domain.keys()) & set(keys) - return MultiField.from_dict({key: self[key] for key in keys}) - def unite(self, other): """Merges two MultiFields on potentially different MultiDomains. diff --git a/nifty5/operators/chain_operator.py b/nifty5/operators/chain_operator.py index 922063605265e6581727de5a74edd005038cb81c..af46f4fefe5d1914745fd8164bbdb6549b05ed2a 100644 --- a/nifty5/operators/chain_operator.py +++ b/nifty5/operators/chain_operator.py @@ -138,14 +138,7 @@ class ChainOperator(LinearOperator): subs = "\n".join(sub.__repr__() for sub in self._ops) return "ChainOperator:\n" + utilities.indent(subs) - def simplify_for_constant_input(self, c_inp): - if c_inp is None: - return None, self - if c_inp.domain == self.domain: - from .operator import _ConstantOperator - op = _ConstantOperator(self.domain, self(c_inp)) - return op(c_inp), op - + def _simplify_for_constant_input_nontrivial(self, c_inp): from ..multi_domain import MultiDomain if not isinstance(self._domain, MultiDomain): return None, self diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py index 374614641ac1b9f726a51afb65642ea44b8634af..8ee4970769d414f05d7b9fa755811654251e3035 100644 --- a/nifty5/operators/operator.py +++ b/nifty5/operators/operator.py @@ -147,10 +147,15 @@ class Operator(metaclass=NiftyMeta): return self.__class__.__name__ def simplify_for_constant_input(self, c_inp): - if c_inp is None or c_inp.domain != self.domain: + if c_inp is None: return None, self - op = _ConstantOperator(self.domain, self(c_inp)) - return op(c_inp), op + if c_inp.domain == self.domain: + op = _ConstantOperator(self.domain, self(c_inp)) + return op(c_inp), op + return self._simplify_for_constant_input_nontrivial(c_inp) + + def _simplify_for_constant_input_nontrivial(self, c_inp): + return None, self for f in ["sqrt", "exp", "log", "tanh", "sigmoid", 'sin', 'cos', 'tan', @@ -222,33 +227,6 @@ class _ConstantOperator(Operator): return 'ConstantOperator <- {}'.format(self.domain.keys()) -class _ConstantOperator2(Operator): - def __init__(self, target, constant_output): - from ..sugar import makeDomain - self._target = makeDomain(target) - dom_keys = set(target.keys())-set(constant_output.domain.keys()) - self._domain = makeDomain({key: self._target[key] for key in dom_keys}) - self._constant_output = constant_output - - def apply(self, x): - from ..linearization import Linearization - self._check_input(x) - if not isinstance(x, Linearization): - return x.unite(self._constant_output) - from .simple_linear_operators import _PartialExtractor - - op = _PartialExtractor(self.target, x.jac.target).adjoint - val = x.val.unite(self._constant_output) - - assert val.domain is self.target - assert val.domain is op.target - - return x.new(val, op(x.jac)) - - def __repr__(self): - return 'ConstantOperator2: {} <- {}'.format(self.target.keys(), self.domain.keys()) - - class _FunctionApplier(Operator): def __init__(self, domain, funcname): from ..sugar import makeDomain @@ -321,13 +299,7 @@ class _OpChain(_CombinedOperator): x = op(x) return x - def simplify_for_constant_input(self, c_inp): - if c_inp is None: - return None, self - if c_inp.domain == self.domain: - op = _ConstantOperator(self.domain, self(c_inp)) - return op(c_inp), op - + def _simplify_for_constant_input_nontrivial(self, c_inp): from ..multi_domain import MultiDomain if not isinstance(self._domain, MultiDomain): return None, self @@ -370,13 +342,7 @@ class _OpProd(Operator): makeOp(lin2._val)(lin1._jac), False) return lin1.new(lin1._val*lin2._val, op(x.jac)) - def simplify_for_constant_input(self, c_inp): - if c_inp is None: - return None, self - if c_inp.domain == self.domain: - op = _ConstantOperator(self.domain, self(c_inp)) - return op(c_inp), op - + def _simplify_for_constant_input_nontrivial(self, c_inp): f1, o1 = self._op1.simplify_for_constant_input( c_inp.extract_part(self._op1.domain)) f2, o2 = self._op2.simplify_for_constant_input( @@ -423,13 +389,7 @@ class _OpSum(Operator): res = res.add_metric(lin1._metric + lin2._metric) return res - def simplify_for_constant_input(self, c_inp): - if c_inp is None: - return None, self - if c_inp.domain == self.domain: - op = _ConstantOperator(self.domain, self(c_inp)) - return op(c_inp), op - + def _simplify_for_constant_input_nontrivial(self, c_inp): f1, o1 = self._op1.simplify_for_constant_input( c_inp.extract_part(self._op1.domain)) f2, o2 = self._op2.simplify_for_constant_input( diff --git a/nifty5/operators/sum_operator.py b/nifty5/operators/sum_operator.py index 675fa73e0feb7d8660c0f1fc3baa57d294d3160b..e621979e1ec6770c2fc8822c32d637e498e389ca 100644 --- a/nifty5/operators/sum_operator.py +++ b/nifty5/operators/sum_operator.py @@ -201,14 +201,7 @@ class SumOperator(LinearOperator): subs = "\n".join(sub.__repr__() for sub in self._ops) return "SumOperator:\n"+indent(subs) - def simplify_for_constant_input(self, c_inp): - if c_inp is None: - return None, self - if c_inp.domain == self.domain: - from .operator import _ConstantOperator - op = _ConstantOperator(self.domain, self(c_inp)) - return op(c_inp), op - + def _simplify_for_constant_input_nontrivial(self, c_inp): f=[] o=[] for op in self._ops: