Commit 2fbd9937 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

hackish implementation for aggressive combination of DiagonalOperators

parent 1883e5d9
...@@ -61,9 +61,9 @@ class ChainOperator(LinearOperator): ...@@ -61,9 +61,9 @@ class ChainOperator(LinearOperator):
# try to absorb the factor into a DiagonalOperator # try to absorb the factor into a DiagonalOperator
for i in range(len(opsnew)): for i in range(len(opsnew)):
if isinstance(opsnew[i], DiagonalOperator): if isinstance(opsnew[i], DiagonalOperator):
opsnew[i] = DiagonalOperator(opsnew[i].diagonal*fct, opsnew[i] = DiagonalOperator(None, opsnew[i].domain,
domain=opsnew[i].domain, opsnew[i]._spaces,
spaces=opsnew[i]._spaces) opsnew[i]._ldiag*fct)
fct = 1. fct = 1.
break break
if fct != 1: if fct != 1:
...@@ -75,12 +75,13 @@ class ChainOperator(LinearOperator): ...@@ -75,12 +75,13 @@ class ChainOperator(LinearOperator):
for op in ops: for op in ops:
if (len(opsnew) > 0 and if (len(opsnew) > 0 and
isinstance(opsnew[-1], DiagonalOperator) and isinstance(opsnew[-1], DiagonalOperator) and
isinstance(op, DiagonalOperator) and isinstance(op, DiagonalOperator)):
op._spaces == opsnew[-1]._spaces): if opsnew[-1]._spaces is None or op._spaces is None:
opsnew[-1] = DiagonalOperator(opsnew[-1].diagonal * spc = None
op.diagonal, else:
domain=opsnew[-1].domain, spc = tuple(set(opsnew[-1]._spaces) | set(op._spaces))
spaces=opsnew[-1]._spaces) ldiag = opsnew[-1]._ldiag * op._ldiag
opsnew[-1] = DiagonalOperator(None, op.domain, spc, ldiag)
else: else:
opsnew.append(op) opsnew.append(op)
ops = opsnew ops = opsnew
......
...@@ -54,9 +54,15 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -54,9 +54,15 @@ class DiagonalOperator(EndomorphicOperator):
This shortcoming will hopefully be fixed in the future. This shortcoming will hopefully be fixed in the future.
""" """
def __init__(self, diagonal, domain=None, spaces=None): def __init__(self, diagonal, domain=None, spaces=None, _ldiag=None):
super(DiagonalOperator, self).__init__() super(DiagonalOperator, self).__init__()
if _ldiag is not None: # very special hack
self._ldiag = _ldiag
self._domain = domain
self._spaces = spaces
return
if not isinstance(diagonal, Field): if not isinstance(diagonal, Field):
raise TypeError("Field object required") raise TypeError("Field object required")
if domain is None: if domain is None:
...@@ -116,11 +122,6 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -116,11 +122,6 @@ class DiagonalOperator(EndomorphicOperator):
else: else:
return Field(x.domain, val=x.val/self._ldiag.conj()) return Field(x.domain, val=x.val/self._ldiag.conj())
@property
def diagonal(self):
""" Returns the diagonal of the Operator."""
return self._diagonal
@property @property
def domain(self): def domain(self):
return self._domain return self._domain
...@@ -131,12 +132,13 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -131,12 +132,13 @@ class DiagonalOperator(EndomorphicOperator):
@property @property
def inverse(self): def inverse(self):
return DiagonalOperator(1./self._diagonal, self._domain, self._spaces) return DiagonalOperator(None, self._domain, self._spaces,
1./self._ldiag)
@property @property
def adjoint(self): def adjoint(self):
return DiagonalOperator(self._diagonal.conjugate(), self._domain, return DiagonalOperator(None, self._domain,
self._spaces) self._spaces, self._ldiag.conjugate())
def process_sample(self, sample): def process_sample(self, sample):
if np.issubdtype(self._ldiag.dtype, np.complexfloating): if np.issubdtype(self._ldiag.dtype, np.complexfloating):
......
...@@ -71,9 +71,9 @@ class SumOperator(LinearOperator): ...@@ -71,9 +71,9 @@ class SumOperator(LinearOperator):
for i in range(len(opsnew)): for i in range(len(opsnew)):
if isinstance(opsnew[i], DiagonalOperator): if isinstance(opsnew[i], DiagonalOperator):
sum *= (-1 if negnew[i] else 1) sum *= (-1 if negnew[i] else 1)
opsnew[i] = DiagonalOperator(opsnew[i].diagonal+sum, opsnew[i] = DiagonalOperator(None, opsnew[i].domain,
domain=opsnew[i].domain, opsnew[i]._spaces,
spaces=opsnew[i]._spaces) opsnew[i]._ldiag+sum)
sum = 0. sum = 0.
break break
if sum != 0: if sum != 0:
...@@ -89,14 +89,14 @@ class SumOperator(LinearOperator): ...@@ -89,14 +89,14 @@ class SumOperator(LinearOperator):
for i in range(len(ops)): for i in range(len(ops)):
if not processed[i]: if not processed[i]:
if isinstance(ops[i], DiagonalOperator): if isinstance(ops[i], DiagonalOperator):
diag = ops[i].diagonal*(-1 if neg[i] else 1) ldiag = ops[i]._ldiag*(-1 if neg[i] else 1)
for j in range(i+1, len(ops)): for j in range(i+1, len(ops)):
if (isinstance(ops[j], DiagonalOperator) and if (isinstance(ops[j], DiagonalOperator) and
ops[i]._spaces == ops[j]._spaces): ops[i]._spaces == ops[j]._spaces):
diag += ops[j].diagonal*(-1 if neg[j] else 1) ldiag += ops[j]._ldiag*(-1 if neg[j] else 1)
processed[j] = True processed[j] = True
opsnew.append(DiagonalOperator(diag, ops[i].domain, opsnew.append(DiagonalOperator(None, ops[i].domain,
ops[i]._spaces)) ops[i]._spaces, ldiag))
negnew.append(False) negnew.append(False)
else: else:
opsnew.append(ops[i]) opsnew.append(ops[i])
......
...@@ -90,5 +90,5 @@ class DiagonalOperator_Tests(unittest.TestCase): ...@@ -90,5 +90,5 @@ class DiagonalOperator_Tests(unittest.TestCase):
def test_diagonal(self, space): def test_diagonal(self, space):
diag = ift.Field.from_random('normal', domain=space) diag = ift.Field.from_random('normal', domain=space)
D = ift.DiagonalOperator(diag) D = ift.DiagonalOperator(diag)
diag_op = D.diagonal diag_op = D(ift.Field.full(space, 1.))
assert_allclose(diag.to_global_data(), diag_op.to_global_data()) assert_allclose(diag.to_global_data(), diag_op.to_global_data())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment