Skip to content
Snippets Groups Projects
Commit d3e0dbec authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cleanup

parent 4e671932
No related branches found
No related tags found
1 merge request!295Simplify for const
......@@ -219,10 +219,6 @@ class MultiField(object):
return MultiField.from_dict({key: self[key] for key in subset.keys()
if key in self})
def extract_by_keys(self, keys):
keys = set(self.domain.keys()) & set(keys)
return MultiField.from_dict({key: self[key] for key in keys})
def unite(self, other):
"""Merges two MultiFields on potentially different MultiDomains.
......
......@@ -138,14 +138,7 @@ class ChainOperator(LinearOperator):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "ChainOperator:\n" + utilities.indent(subs)
def simplify_for_constant_input(self, c_inp):
if c_inp is None:
return None, self
if c_inp.domain == self.domain:
from .operator import _ConstantOperator
op = _ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
if not isinstance(self._domain, MultiDomain):
return None, self
......
......@@ -147,10 +147,15 @@ class Operator(metaclass=NiftyMeta):
return self.__class__.__name__
def simplify_for_constant_input(self, c_inp):
if c_inp is None or c_inp.domain != self.domain:
if c_inp is None:
return None, self
if c_inp.domain == self.domain:
op = _ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
return self._simplify_for_constant_input_nontrivial(c_inp)
def _simplify_for_constant_input_nontrivial(self, c_inp):
return None, self
for f in ["sqrt", "exp", "log", "tanh", "sigmoid", 'sin', 'cos', 'tan',
......@@ -222,33 +227,6 @@ class _ConstantOperator(Operator):
return 'ConstantOperator <- {}'.format(self.domain.keys())
class _ConstantOperator2(Operator):
def __init__(self, target, constant_output):
from ..sugar import makeDomain
self._target = makeDomain(target)
dom_keys = set(target.keys())-set(constant_output.domain.keys())
self._domain = makeDomain({key: self._target[key] for key in dom_keys})
self._constant_output = constant_output
def apply(self, x):
from ..linearization import Linearization
self._check_input(x)
if not isinstance(x, Linearization):
return x.unite(self._constant_output)
from .simple_linear_operators import _PartialExtractor
op = _PartialExtractor(self.target, x.jac.target).adjoint
val = x.val.unite(self._constant_output)
assert val.domain is self.target
assert val.domain is op.target
return x.new(val, op(x.jac))
def __repr__(self):
return 'ConstantOperator2: {} <- {}'.format(self.target.keys(), self.domain.keys())
class _FunctionApplier(Operator):
def __init__(self, domain, funcname):
from ..sugar import makeDomain
......@@ -321,13 +299,7 @@ class _OpChain(_CombinedOperator):
x = op(x)
return x
def simplify_for_constant_input(self, c_inp):
if c_inp is None:
return None, self
if c_inp.domain == self.domain:
op = _ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
def _simplify_for_constant_input_nontrivial(self, c_inp):
from ..multi_domain import MultiDomain
if not isinstance(self._domain, MultiDomain):
return None, self
......@@ -370,13 +342,7 @@ class _OpProd(Operator):
makeOp(lin2._val)(lin1._jac), False)
return lin1.new(lin1._val*lin2._val, op(x.jac))
def simplify_for_constant_input(self, c_inp):
if c_inp is None:
return None, self
if c_inp.domain == self.domain:
op = _ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
def _simplify_for_constant_input_nontrivial(self, c_inp):
f1, o1 = self._op1.simplify_for_constant_input(
c_inp.extract_part(self._op1.domain))
f2, o2 = self._op2.simplify_for_constant_input(
......@@ -423,13 +389,7 @@ class _OpSum(Operator):
res = res.add_metric(lin1._metric + lin2._metric)
return res
def simplify_for_constant_input(self, c_inp):
if c_inp is None:
return None, self
if c_inp.domain == self.domain:
op = _ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
def _simplify_for_constant_input_nontrivial(self, c_inp):
f1, o1 = self._op1.simplify_for_constant_input(
c_inp.extract_part(self._op1.domain))
f2, o2 = self._op2.simplify_for_constant_input(
......
......@@ -201,14 +201,7 @@ class SumOperator(LinearOperator):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "SumOperator:\n"+indent(subs)
def simplify_for_constant_input(self, c_inp):
if c_inp is None:
return None, self
if c_inp.domain == self.domain:
from .operator import _ConstantOperator
op = _ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
def _simplify_for_constant_input_nontrivial(self, c_inp):
f=[]
o=[]
for op in self._ops:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment