diff --git a/nifty4/operators/chain_operator.py b/nifty4/operators/chain_operator.py index 783f9c151a70dd2c7e82d8eb78c1b3316651e88a..299ad7f64103d33460c56ddce13778ed6a90e235 100644 --- a/nifty4/operators/chain_operator.py +++ b/nifty4/operators/chain_operator.py @@ -61,9 +61,9 @@ class ChainOperator(LinearOperator): # try to absorb the factor into a DiagonalOperator for i in range(len(opsnew)): if isinstance(opsnew[i], DiagonalOperator): - opsnew[i] = DiagonalOperator(opsnew[i].diagonal*fct, - domain=opsnew[i].domain, - spaces=opsnew[i]._spaces) + opsnew[i] = DiagonalOperator(None, opsnew[i].domain, + opsnew[i]._spaces, + opsnew[i]._ldiag*fct) fct = 1. break if fct != 1: @@ -75,12 +75,13 @@ class ChainOperator(LinearOperator): for op in ops: if (len(opsnew) > 0 and isinstance(opsnew[-1], DiagonalOperator) and - isinstance(op, DiagonalOperator) and - op._spaces == opsnew[-1]._spaces): - opsnew[-1] = DiagonalOperator(opsnew[-1].diagonal * - op.diagonal, - domain=opsnew[-1].domain, - spaces=opsnew[-1]._spaces) + isinstance(op, DiagonalOperator)): + if opsnew[-1]._spaces is None or op._spaces is None: + spc = None + else: + spc = tuple(set(opsnew[-1]._spaces) | set(op._spaces)) + ldiag = opsnew[-1]._ldiag * op._ldiag + opsnew[-1] = DiagonalOperator(None, op.domain, spc, ldiag) else: opsnew.append(op) ops = opsnew diff --git a/nifty4/operators/diagonal_operator.py b/nifty4/operators/diagonal_operator.py index 345cede075e3586fdc0f3f3db8e2a0ae301ab296..daaf3b7f49176dd5023f80f9b4c80fb2a1de9dac 100644 --- a/nifty4/operators/diagonal_operator.py +++ b/nifty4/operators/diagonal_operator.py @@ -54,9 +54,15 @@ class DiagonalOperator(EndomorphicOperator): This shortcoming will hopefully be fixed in the future. """ - def __init__(self, diagonal, domain=None, spaces=None): + def __init__(self, diagonal, domain=None, spaces=None, _ldiag=None): super(DiagonalOperator, self).__init__() + if _ldiag is not None: # very special hack + self._ldiag = _ldiag + self._domain = domain + self._spaces = spaces + return + if not isinstance(diagonal, Field): raise TypeError("Field object required") if domain is None: @@ -116,11 +122,6 @@ class DiagonalOperator(EndomorphicOperator): else: return Field(x.domain, val=x.val/self._ldiag.conj()) - @property - def diagonal(self): - """ Returns the diagonal of the Operator.""" - return self._diagonal - @property def domain(self): return self._domain @@ -131,12 +132,13 @@ class DiagonalOperator(EndomorphicOperator): @property def inverse(self): - return DiagonalOperator(1./self._diagonal, self._domain, self._spaces) + return DiagonalOperator(None, self._domain, self._spaces, + 1./self._ldiag) @property def adjoint(self): - return DiagonalOperator(self._diagonal.conjugate(), self._domain, - self._spaces) + return DiagonalOperator(None, self._domain, + self._spaces, self._ldiag.conjugate()) def process_sample(self, sample): if np.issubdtype(self._ldiag.dtype, np.complexfloating): diff --git a/nifty4/operators/sum_operator.py b/nifty4/operators/sum_operator.py index 6e8a253c19b27cdc02a54aeeae190d0e82b3f52c..1c223835aa79f4c606a51a68dd2bc98023a77e0c 100644 --- a/nifty4/operators/sum_operator.py +++ b/nifty4/operators/sum_operator.py @@ -71,9 +71,9 @@ class SumOperator(LinearOperator): for i in range(len(opsnew)): if isinstance(opsnew[i], DiagonalOperator): sum *= (-1 if negnew[i] else 1) - opsnew[i] = DiagonalOperator(opsnew[i].diagonal+sum, - domain=opsnew[i].domain, - spaces=opsnew[i]._spaces) + opsnew[i] = DiagonalOperator(None, opsnew[i].domain, + opsnew[i]._spaces, + opsnew[i]._ldiag+sum) sum = 0. break if sum != 0: @@ -89,14 +89,14 @@ class SumOperator(LinearOperator): for i in range(len(ops)): if not processed[i]: if isinstance(ops[i], DiagonalOperator): - diag = ops[i].diagonal*(-1 if neg[i] else 1) + ldiag = ops[i]._ldiag*(-1 if neg[i] else 1) for j in range(i+1, len(ops)): if (isinstance(ops[j], DiagonalOperator) and ops[i]._spaces == ops[j]._spaces): - diag += ops[j].diagonal*(-1 if neg[j] else 1) + ldiag += ops[j]._ldiag*(-1 if neg[j] else 1) processed[j] = True - opsnew.append(DiagonalOperator(diag, ops[i].domain, - ops[i]._spaces)) + opsnew.append(DiagonalOperator(None, ops[i].domain, + ops[i]._spaces, ldiag)) negnew.append(False) else: opsnew.append(ops[i]) diff --git a/test/test_operators/test_diagonal_operator.py b/test/test_operators/test_diagonal_operator.py index 5f3174404c866ff907d7fe915a3b9ae8c60a4a1b..5569c761702383def26f8e37ba06020704c9d3de 100644 --- a/test/test_operators/test_diagonal_operator.py +++ b/test/test_operators/test_diagonal_operator.py @@ -90,5 +90,5 @@ class DiagonalOperator_Tests(unittest.TestCase): def test_diagonal(self, space): diag = ift.Field.from_random('normal', domain=space) D = ift.DiagonalOperator(diag) - diag_op = D.diagonal + diag_op = D(ift.Field.full(space, 1.)) assert_allclose(diag.to_global_data(), diag_op.to_global_data())