Commit 1f5a367c authored by Martin Reinecke's avatar Martin Reinecke

tweak SumOperator; fix a few other problems

parent 7467d04d
Pipeline #23410 passed with stage
in 4 minutes and 38 seconds
......@@ -70,7 +70,7 @@ if __name__ == "__main__":
mask_2[N2_10*7:N2_10*9] = 0.
mask_2 = ift.Field(signal_space_2, ift.dobj.from_global_data(mask_2))
R = ift.ResponseOperator(signal_domain, spaces=(0, 1),
R = ift.ResponseOperator(signal_domain,
sigma=(response_sigma_1, response_sigma_2),
exposure=(mask_1, mask_2))
data_domain = R.target
......
......@@ -370,10 +370,10 @@ class Field(object):
return self.copy()
def __neg__(self):
return Field(self._domain, -self.val, self.dtype)
return Field(self._domain, -self.val)
def __abs__(self):
return Field(self._domain, dobj.abs(self.val), self.dtype)
return Field(self._domain, dobj.abs(self.val))
def _contraction_helper(self, op, spaces):
if spaces is None:
......
......@@ -104,7 +104,7 @@ class LinearOperator(with_metaclass(
def __add__(self, other):
from .sum_operator import SumOperator
other = self._toOperator(other, self.domain)
return SumOperator(self, other)
return SumOperator.make([self, other], [False, False])
def __radd__(self, other):
return self.__add__(other)
......@@ -112,13 +112,13 @@ class LinearOperator(with_metaclass(
def __sub__(self, other):
from .sum_operator import SumOperator
other = self._toOperator(other, self.domain)
return SumOperator(self, other, neg=True)
return SumOperator.make([self, other], [False, True])
# MR FIXME: this might be more complicated ...
def __rsub__(self, other):
from .sum_operator import SumOperator
other = self._toOperator(other, self.domain)
return SumOperator(other, self, neg=True)
return SumOperator.make(other, self, [False, True])
def supports(self, ops):
return False
......
......@@ -20,20 +20,80 @@ from .linear_operator import LinearOperator
class SumOperator(LinearOperator):
def __init__(self, op1, op2, neg=False):
def __init__(self, ops, neg, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(SumOperator, self).__init__()
if op1.domain != op2.domain or op1.target != op2.target:
raise ValueError("domain mismatch")
self._capability = (op1.capability & op2.capability &
(self.TIMES | self.ADJOINT_TIMES))
neg1 = op1._neg if isinstance(op1, SumOperator) else (False,)
op1 = op1._ops if isinstance(op1, SumOperator) else (op1,)
neg2 = op2._neg if isinstance(op2, SumOperator) else (False,)
op2 = op2._ops if isinstance(op2, SumOperator) else (op2,)
if neg:
neg2 = tuple(not n for n in neg2)
self._ops = op1 + op2
self._neg = neg1 + neg2
self._ops = ops
self._neg = neg
self._capability = self.TIMES | self.ADJOINT_TIMES
for op in ops:
self._capability &= op.capability
@staticmethod
def simplify(ops, neg):
from .scaling_operator import ScalingOperator
from .diagonal_operator import DiagonalOperator
# Step 1: verify domains
for op in ops[1:]:
if op.domain != ops[0].domain or op.target != ops[0].target:
raise ValueError("domain mismatch")
# Step 2: unpack SumOperators
opsnew = []
negnew = []
for op, ng in zip (ops, neg):
if isinstance(op, SumOperator):
opsnew += op._ops
if ng:
negtmp += [not n for n in ng]
else:
negtmp += list(ng)
else:
opsnew.append(op)
negnew.append(ng)
ops = opsnew
neg = negnew
# Step 3: collect ScalingOperators
sum = 0.
opsnew = []
negnew = []
lastdom = ops[-1].domain
for op, ng in zip(ops, neg):
if isinstance(op, ScalingOperator):
sum += op._factor * (-1 if ng else 1)
else:
opsnew.append(op)
negnew.append(ng)
if sum != 0.:
# try to absorb the factor into a DiagonalOperator
for i in range(len(opsnew)):
if isinstance(opsnew[i], DiagonalOperator):
sum *= (-1 if negnew[i] else 1)
opsnew[i] = DiagonalOperator(opsnew[i].diagonal()+sum,
domain=opsnew[i].domain,
spaces=opsnew[i]._spaces)
sum = 0.
break
if sum != 0:
# have to add the scaling operator at the end
opsnew.append(ScalingOperator(sum, lastdom))
newnew.append(False)
ops = opsnew
neg = negnew
# Step 4: combine DiagonalOperators where possible
# (TBD)
return ops, neg
@staticmethod
def make(ops, neg):
ops = tuple(ops)
neg = tuple(neg)
if len(ops)!= len(neg):
raise ValueError("length mismatch between ops and neg")
ops, neg = SumOperator.simplify(ops, neg)
if len(ops) == 1 and not neg[0]:
return ops[0]
return SumOperator(ops, neg, _callingfrommake=True)
@property
def domain(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment