Commit 7dc878a0 authored by Martin Reinecke's avatar Martin Reinecke

try to fix _OpSum

parent 12439028
......@@ -55,7 +55,7 @@ class Operator(NiftyMetaBase()):
def __add__(self, x):
if not isinstance(x, Operator):
return NotImplemented
return _OpSum.make([self, x])
return _OpSum(self, x)
def apply(self, x):
raise NotImplementedError
......@@ -152,16 +152,28 @@ class _OpProd(Operator):
return Linearization(lin1._val*lin2._val, op(x.jac))
class _OpSum(_CombinedOperator):
def __init__(self, ops, _callingfrommake=False):
class _OpSum(Operator):
def __init__(self, op1, op2):
from ..sugar import domain_union
super(_OpSum, self).__init__(ops, _callingfrommake)
self._domain = domain_union([op.domain for op in self._ops])
self._target = domain_union([op.target for op in self._ops])
self._domain = domain_union((op1.domain, op2.domain))
self._target = domain_union((op1.target, op2.target))
self._op1 = op1
self._op2 = op2
def apply(self, x):
from ..linearization import Linearization
from ..sugar import makeOp
lin = isinstance(x, Linearization)
v = x._val if lin else x
v1 = v.extract(self._op1.domain)
v2 = v.extract(self._op2.domain)
res = None
for op in self._ops:
tmp = op(x.extract(op.domain))
res = tmp if res is None else res.unite(tmp)
if not lin:
return self._op1(v1).unite(self._op2(v2))
lin1 = self._op1(Linearization.make_var(v1))
lin2 = self._op2(Linearization.make_var(v2))
op = lin1._jac._myadd(lin2._jac, False)
res = Linearization(lin1._val+lin2._val, op(x.jac))
if lin1._metric is not None and lin2._metric is not None:
res = res.add_metric(lin1._metric + lin2._metric)
return res
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment