Commit 38f2fc5c authored by Martin Reinecke's avatar Martin Reinecke

improve optimizations

parent 2fbd9937
Pipeline #26304 passed with stage
in 5 minutes and 25 seconds
......@@ -61,9 +61,7 @@ class ChainOperator(LinearOperator):
# try to absorb the factor into a DiagonalOperator
for i in range(len(opsnew)):
if isinstance(opsnew[i], DiagonalOperator):
opsnew[i] = DiagonalOperator(None, opsnew[i].domain,
opsnew[i]._spaces,
opsnew[i]._ldiag*fct)
opsnew[i] = opsnew[i]._scale(fct)
fct = 1.
break
if fct != 1:
......@@ -76,12 +74,7 @@ class ChainOperator(LinearOperator):
if (len(opsnew) > 0 and
isinstance(opsnew[-1], DiagonalOperator) and
isinstance(op, DiagonalOperator)):
if opsnew[-1]._spaces is None or op._spaces is None:
spc = None
else:
spc = tuple(set(opsnew[-1]._spaces) | set(op._spaces))
ldiag = opsnew[-1]._ldiag * op._ldiag
opsnew[-1] = DiagonalOperator(None, op.domain, spc, ldiag)
opsnew[-1] = opsnew[-1]._combine_prod(op)
else:
opsnew.append(op)
ops = opsnew
......
......@@ -54,15 +54,9 @@ class DiagonalOperator(EndomorphicOperator):
This shortcoming will hopefully be fixed in the future.
"""
def __init__(self, diagonal, domain=None, spaces=None, _ldiag=None):
def __init__(self, diagonal, domain=None, spaces=None):
super(DiagonalOperator, self).__init__()
if _ldiag is not None: # very special hack
self._ldiag = _ldiag
self._domain = domain
self._spaces = spaces
return
if not isinstance(diagonal, Field):
raise TypeError("Field object required")
if domain is None:
......@@ -85,24 +79,60 @@ class DiagonalOperator(EndomorphicOperator):
if self._spaces == tuple(range(len(self._domain))):
self._spaces = None # shortcut
self._diagonal = diagonal.lock()
if self._spaces is not None:
active_axes = []
for space_index in self._spaces:
active_axes += self._domain.axes[space_index]
if self._spaces[0] == 0:
self._ldiag = self._diagonal.local_data
self._ldiag = diagonal.local_data
else:
self._ldiag = self._diagonal.to_global_data()
self._ldiag = diagonal.to_global_data()
locshape = dobj.local_shape(self._domain.shape, 0)
self._reshaper = [shp if i in active_axes else 1
for i, shp in enumerate(locshape)]
self._ldiag = self._ldiag.reshape(self._reshaper)
else:
self._ldiag = self._diagonal.local_data
self._ldiag = diagonal.local_data
self._ldiag.flags.writeable = False
def _skeleton(self, spc):
res = DiagonalOperator.__new__(DiagonalOperator)
res._domain = self._domain
if self._spaces is None or spc is None:
res._spaces = None
else:
res._spaces = tuple(set(self._spaces) | set(spc))
return res
def _scale(self, fct):
if not np.isscalar(fct):
raise TypeError("scalar value required")
res = self._skeleton(())
res._ldiag = self._ldiag*fct
return res
def _add(self, sum):
if not np.isscalar(sum):
raise TypeError("scalar value required")
res = self._skeleton(())
res._ldiag = self._ldiag + sum
return res
def _combine_prod(self, op):
if not isinstance(op, DiagonalOperator):
raise TypeError("DiagonalOperator required")
res = self._skeleton(op._spaces)
res._ldiag = self._ldiag*op._ldiag
return res
def _combine_sum(self, op, selfneg, opneg):
if not isinstance(op, DiagonalOperator):
raise TypeError("DiagonalOperator required")
res = self._skeleton(op._spaces)
res._ldiag = (self._ldiag * (-1 if selfneg else 1) +
op._ldiag * (-1 if opneg else 1))
return res
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -132,13 +162,15 @@ class DiagonalOperator(EndomorphicOperator):
@property
def inverse(self):
return DiagonalOperator(None, self._domain, self._spaces,
1./self._ldiag)
res = self._skeleton(())
res._ldiag = 1./self._ldiag
return res
@property
def adjoint(self):
return DiagonalOperator(None, self._domain,
self._spaces, self._ldiag.conjugate())
res = self._skeleton(())
res._ldiag = self._ldiag.conjugate()
return res
def process_sample(self, sample):
if np.issubdtype(self._ldiag.dtype, np.complexfloating):
......
......@@ -71,9 +71,7 @@ class SumOperator(LinearOperator):
for i in range(len(opsnew)):
if isinstance(opsnew[i], DiagonalOperator):
sum *= (-1 if negnew[i] else 1)
opsnew[i] = DiagonalOperator(None, opsnew[i].domain,
opsnew[i]._spaces,
opsnew[i]._ldiag+sum)
opsnew[i] = opsnew[i]._add(sum)
sum = 0.
break
if sum != 0:
......@@ -89,15 +87,15 @@ class SumOperator(LinearOperator):
for i in range(len(ops)):
if not processed[i]:
if isinstance(ops[i], DiagonalOperator):
ldiag = ops[i]._ldiag*(-1 if neg[i] else 1)
op = ops[i]
opneg = neg[i]
for j in range(i+1, len(ops)):
if (isinstance(ops[j], DiagonalOperator) and
ops[i]._spaces == ops[j]._spaces):
ldiag += ops[j]._ldiag*(-1 if neg[j] else 1)
if isinstance(ops[j], DiagonalOperator):
op = op._combine_sum(ops[j], opneg, neg[j])
opneg = False
processed[j] = True
opsnew.append(DiagonalOperator(None, ops[i].domain,
ops[i]._spaces, ldiag))
negnew.append(False)
opsnew.append(op)
negnew.append(opneg)
else:
opsnew.append(ops[i])
negnew.append(neg[i])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment