From 1f5a367c3242739c547b8a7b4bd16600ebdbc9d3 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Mon, 8 Jan 2018 11:18:07 +0100 Subject: [PATCH] tweak SumOperator; fix a few other problems --- demos/paper_demos/cartesian_wiener_filter.py | 2 +- nifty/field.py | 4 +- nifty/operators/linear_operator.py | 6 +- nifty/operators/sum_operator.py | 86 +++++++++++++++++--- 4 files changed, 79 insertions(+), 19 deletions(-) diff --git a/demos/paper_demos/cartesian_wiener_filter.py b/demos/paper_demos/cartesian_wiener_filter.py index 95468b64a..137a0e6ad 100644 --- a/demos/paper_demos/cartesian_wiener_filter.py +++ b/demos/paper_demos/cartesian_wiener_filter.py @@ -70,7 +70,7 @@ if __name__ == "__main__": mask_2[N2_10*7:N2_10*9] = 0. mask_2 = ift.Field(signal_space_2, ift.dobj.from_global_data(mask_2)) - R = ift.ResponseOperator(signal_domain, spaces=(0, 1), + R = ift.ResponseOperator(signal_domain, sigma=(response_sigma_1, response_sigma_2), exposure=(mask_1, mask_2)) data_domain = R.target diff --git a/nifty/field.py b/nifty/field.py index 4428b2026..467561d3f 100644 --- a/nifty/field.py +++ b/nifty/field.py @@ -370,10 +370,10 @@ class Field(object): return self.copy() def __neg__(self): - return Field(self._domain, -self.val, self.dtype) + return Field(self._domain, -self.val) def __abs__(self): - return Field(self._domain, dobj.abs(self.val), self.dtype) + return Field(self._domain, dobj.abs(self.val)) def _contraction_helper(self, op, spaces): if spaces is None: diff --git a/nifty/operators/linear_operator.py b/nifty/operators/linear_operator.py index a27c29461..536aa107a 100644 --- a/nifty/operators/linear_operator.py +++ b/nifty/operators/linear_operator.py @@ -104,7 +104,7 @@ class LinearOperator(with_metaclass( def __add__(self, other): from .sum_operator import SumOperator other = self._toOperator(other, self.domain) - return SumOperator(self, other) + return SumOperator.make([self, other], [False, False]) def __radd__(self, other): return self.__add__(other) @@ -112,13 +112,13 @@ class LinearOperator(with_metaclass( def __sub__(self, other): from .sum_operator import SumOperator other = self._toOperator(other, self.domain) - return SumOperator(self, other, neg=True) + return SumOperator.make([self, other], [False, True]) # MR FIXME: this might be more complicated ... def __rsub__(self, other): from .sum_operator import SumOperator other = self._toOperator(other, self.domain) - return SumOperator(other, self, neg=True) + return SumOperator.make(other, self, [False, True]) def supports(self, ops): return False diff --git a/nifty/operators/sum_operator.py b/nifty/operators/sum_operator.py index 070c1634e..51178b845 100644 --- a/nifty/operators/sum_operator.py +++ b/nifty/operators/sum_operator.py @@ -20,20 +20,80 @@ from .linear_operator import LinearOperator class SumOperator(LinearOperator): - def __init__(self, op1, op2, neg=False): + def __init__(self, ops, neg, _callingfrommake=False): + if not _callingfrommake: + raise NotImplementedError super(SumOperator, self).__init__() - if op1.domain != op2.domain or op1.target != op2.target: - raise ValueError("domain mismatch") - self._capability = (op1.capability & op2.capability & - (self.TIMES | self.ADJOINT_TIMES)) - neg1 = op1._neg if isinstance(op1, SumOperator) else (False,) - op1 = op1._ops if isinstance(op1, SumOperator) else (op1,) - neg2 = op2._neg if isinstance(op2, SumOperator) else (False,) - op2 = op2._ops if isinstance(op2, SumOperator) else (op2,) - if neg: - neg2 = tuple(not n for n in neg2) - self._ops = op1 + op2 - self._neg = neg1 + neg2 + self._ops = ops + self._neg = neg + self._capability = self.TIMES | self.ADJOINT_TIMES + for op in ops: + self._capability &= op.capability + + @staticmethod + def simplify(ops, neg): + from .scaling_operator import ScalingOperator + from .diagonal_operator import DiagonalOperator + # Step 1: verify domains + for op in ops[1:]: + if op.domain != ops[0].domain or op.target != ops[0].target: + raise ValueError("domain mismatch") + # Step 2: unpack SumOperators + opsnew = [] + negnew = [] + for op, ng in zip (ops, neg): + if isinstance(op, SumOperator): + opsnew += op._ops + if ng: + negtmp += [not n for n in ng] + else: + negtmp += list(ng) + else: + opsnew.append(op) + negnew.append(ng) + ops = opsnew + neg = negnew + # Step 3: collect ScalingOperators + sum = 0. + opsnew = [] + negnew = [] + lastdom = ops[-1].domain + for op, ng in zip(ops, neg): + if isinstance(op, ScalingOperator): + sum += op._factor * (-1 if ng else 1) + else: + opsnew.append(op) + negnew.append(ng) + if sum != 0.: + # try to absorb the factor into a DiagonalOperator + for i in range(len(opsnew)): + if isinstance(opsnew[i], DiagonalOperator): + sum *= (-1 if negnew[i] else 1) + opsnew[i] = DiagonalOperator(opsnew[i].diagonal()+sum, + domain=opsnew[i].domain, + spaces=opsnew[i]._spaces) + sum = 0. + break + if sum != 0: + # have to add the scaling operator at the end + opsnew.append(ScalingOperator(sum, lastdom)) + newnew.append(False) + ops = opsnew + neg = negnew + # Step 4: combine DiagonalOperators where possible + # (TBD) + return ops, neg + + @staticmethod + def make(ops, neg): + ops = tuple(ops) + neg = tuple(neg) + if len(ops)!= len(neg): + raise ValueError("length mismatch between ops and neg") + ops, neg = SumOperator.simplify(ops, neg) + if len(ops) == 1 and not neg[0]: + return ops[0] + return SumOperator(ops, neg, _callingfrommake=True) @property def domain(self): -- GitLab