Commit 2fbd9937 authored by Martin Reinecke's avatar Martin Reinecke

hackish implementation for aggressive combination of DiagonalOperators

parent 1883e5d9
......@@ -61,9 +61,9 @@ class ChainOperator(LinearOperator):
# try to absorb the factor into a DiagonalOperator
for i in range(len(opsnew)):
if isinstance(opsnew[i], DiagonalOperator):
opsnew[i] = DiagonalOperator(opsnew[i].diagonal*fct,
domain=opsnew[i].domain,
spaces=opsnew[i]._spaces)
opsnew[i] = DiagonalOperator(None, opsnew[i].domain,
opsnew[i]._spaces,
opsnew[i]._ldiag*fct)
fct = 1.
break
if fct != 1:
......@@ -75,12 +75,13 @@ class ChainOperator(LinearOperator):
for op in ops:
if (len(opsnew) > 0 and
isinstance(opsnew[-1], DiagonalOperator) and
isinstance(op, DiagonalOperator) and
op._spaces == opsnew[-1]._spaces):
opsnew[-1] = DiagonalOperator(opsnew[-1].diagonal *
op.diagonal,
domain=opsnew[-1].domain,
spaces=opsnew[-1]._spaces)
isinstance(op, DiagonalOperator)):
if opsnew[-1]._spaces is None or op._spaces is None:
spc = None
else:
spc = tuple(set(opsnew[-1]._spaces) | set(op._spaces))
ldiag = opsnew[-1]._ldiag * op._ldiag
opsnew[-1] = DiagonalOperator(None, op.domain, spc, ldiag)
else:
opsnew.append(op)
ops = opsnew
......
......@@ -54,9 +54,15 @@ class DiagonalOperator(EndomorphicOperator):
This shortcoming will hopefully be fixed in the future.
"""
def __init__(self, diagonal, domain=None, spaces=None):
def __init__(self, diagonal, domain=None, spaces=None, _ldiag=None):
super(DiagonalOperator, self).__init__()
if _ldiag is not None: # very special hack
self._ldiag = _ldiag
self._domain = domain
self._spaces = spaces
return
if not isinstance(diagonal, Field):
raise TypeError("Field object required")
if domain is None:
......@@ -116,11 +122,6 @@ class DiagonalOperator(EndomorphicOperator):
else:
return Field(x.domain, val=x.val/self._ldiag.conj())
@property
def diagonal(self):
""" Returns the diagonal of the Operator."""
return self._diagonal
@property
def domain(self):
return self._domain
......@@ -131,12 +132,13 @@ class DiagonalOperator(EndomorphicOperator):
@property
def inverse(self):
return DiagonalOperator(1./self._diagonal, self._domain, self._spaces)
return DiagonalOperator(None, self._domain, self._spaces,
1./self._ldiag)
@property
def adjoint(self):
return DiagonalOperator(self._diagonal.conjugate(), self._domain,
self._spaces)
return DiagonalOperator(None, self._domain,
self._spaces, self._ldiag.conjugate())
def process_sample(self, sample):
if np.issubdtype(self._ldiag.dtype, np.complexfloating):
......
......@@ -71,9 +71,9 @@ class SumOperator(LinearOperator):
for i in range(len(opsnew)):
if isinstance(opsnew[i], DiagonalOperator):
sum *= (-1 if negnew[i] else 1)
opsnew[i] = DiagonalOperator(opsnew[i].diagonal+sum,
domain=opsnew[i].domain,
spaces=opsnew[i]._spaces)
opsnew[i] = DiagonalOperator(None, opsnew[i].domain,
opsnew[i]._spaces,
opsnew[i]._ldiag+sum)
sum = 0.
break
if sum != 0:
......@@ -89,14 +89,14 @@ class SumOperator(LinearOperator):
for i in range(len(ops)):
if not processed[i]:
if isinstance(ops[i], DiagonalOperator):
diag = ops[i].diagonal*(-1 if neg[i] else 1)
ldiag = ops[i]._ldiag*(-1 if neg[i] else 1)
for j in range(i+1, len(ops)):
if (isinstance(ops[j], DiagonalOperator) and
ops[i]._spaces == ops[j]._spaces):
diag += ops[j].diagonal*(-1 if neg[j] else 1)
ldiag += ops[j]._ldiag*(-1 if neg[j] else 1)
processed[j] = True
opsnew.append(DiagonalOperator(diag, ops[i].domain,
ops[i]._spaces))
opsnew.append(DiagonalOperator(None, ops[i].domain,
ops[i]._spaces, ldiag))
negnew.append(False)
else:
opsnew.append(ops[i])
......
......@@ -90,5 +90,5 @@ class DiagonalOperator_Tests(unittest.TestCase):
def test_diagonal(self, space):
diag = ift.Field.from_random('normal', domain=space)
D = ift.DiagonalOperator(diag)
diag_op = D.diagonal
diag_op = D(ift.Field.full(space, 1.))
assert_allclose(diag.to_global_data(), diag_op.to_global_data())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment