Commit d9d0e0b4 authored by Martin Reinecke's avatar Martin Reinecke

more tweaking

parent 6492d74b
Pipeline #23387 passed with stage
in 4 minutes and 35 seconds
......@@ -154,9 +154,14 @@ class data_object(object):
raise ValueError("distributions are incompatible.")
a = a._data
b = b._data
else:
elif np.isscalar(other):
a = a._data
b = other
elif isinstance(other, np.ndarray):
a = a._data
b = other
else:
return NotImplemented
tval = getattr(a, op)(b)
if tval is a:
......
from ..operators import EndomorphicOperator, InversionEnabler
from ..operators import InversionEnabler
from .response_operators import LinearizedPowerResponse
class NonlinearPowerCurvature(EndomorphicOperator):
class _Helper(EndomorphicOperator):
def __init__(self, position, FFT, Instrument, nonlinearity,
Projection, N, T, sample_list):
super(NonlinearPowerCurvature._Helper, self).__init__()
self.N = N
self.FFT = FFT
self.Instrument = Instrument
self.T = T
self.sample_list = sample_list
self.position = position
self.Projection = Projection
self.nonlinearity = nonlinearity
@property
def domain(self):
return self.position.domain
@property
def capability(self):
return self.TIMES
def apply(self, x, mode):
self._check_input(x, mode)
result = None
for sample in self.sample_list:
if result is None:
result = self._sample_times(x, sample)
else:
result += self._sample_times(x, sample)
result *= 1./len(self.sample_list)
return result + self.T(x)
def _sample_times(self, x, sample):
LinearizedResponse = LinearizedPowerResponse(
self.Instrument, self.nonlinearity, self.FFT, self.Projection,
self.position, sample)
return LinearizedResponse.adjoint_times(
self.N.inverse_times(LinearizedResponse(x)))
def __init__(self, position, FFT, Instrument, nonlinearity,
Projection, N, T, sample_list, inverter):
super(NonlinearPowerCurvature, self).__init__()
self._op = self._Helper(position, FFT, Instrument, nonlinearity,
Projection, N, T, sample_list)
self._op = InversionEnabler(self._op, inverter)
@property
def domain(self):
return self._op.domain
@property
def capability(self):
return self._op.capability
def apply(self, x, mode):
return self._op.apply(x, mode)
def NonlinearPowerCurvature(position, FFT, Instrument, nonlinearity,
Projection, N, T, sample_list, inverter):
result = None
for sample in sample_list:
LinearizedResponse = LinearizedPowerResponse(
Instrument, nonlinearity, FFT, Projection, position, sample)
op = LinearizedResponse.adjoint*N.inverse*LinearizedResponse
result = op if result is None else result + op
result = result*(1./len(sample_list)) + T
return InversionEnabler(result, inverter)
......@@ -20,14 +20,76 @@ from .linear_operator import LinearOperator
class ChainOperator(LinearOperator):
def __init__(self, op1, op2):
def __init__(self, ops, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(ChainOperator, self).__init__()
if op2.target != op1.domain:
raise ValueError("domain mismatch")
self._capability = op1.capability & op2.capability
op1 = op1._ops if isinstance(op1, ChainOperator) else (op1,)
op2 = op2._ops if isinstance(op2, ChainOperator) else (op2,)
self._ops = op1 + op2
self._ops = ops
self._capability = self._all_ops
for op in ops:
self._capability &= op.capability
@staticmethod
def simplify(ops):
from .scaling_operator import ScalingOperator
from .diagonal_operator import DiagonalOperator
# Step 1: verify domains
for i in range(len(ops)-1):
if ops[i+1].target != ops[i].domain:
raise ValueError("domain mismatch")
# Step 2: unpack ChainOperators
opsnew = []
for op in ops:
if isinstance(op, ChainOperator):
opsnew += op._ops
else:
opsnew.append(op)
ops = opsnew
# Step 3: collect ScalingOperators
fct = 1.
opsnew = []
lastdom = ops[-1].domain
for op in ops:
if isinstance(op, ScalingOperator):
fct *= op._factor
else:
opsnew.append(op)
if fct != 1.:
# try to absorb the factor into a DiagonalOperator
for i in range(len(opsnew)):
if isinstance(opsnew[i], DiagonalOperator):
opsnew[i] = DiagonalOperator(opsnew[i].diagonal()*fct,
domain=opsnew[i].domain,
spaces=opsnew[i]._spaces)
fct = 1.
break
if fct != 1:
# have to add the scaling operator at the end
opsnew.append(ScalingOperator(fct, lastdom))
ops = opsnew
# Step 4: combine DiagonalOperators where possible
opsnew = []
for op in ops:
if (len(opsnew) > 0 and
isinstance(opsnew[-1], DiagonalOperator) and
isinstance(op, DiagonalOperator) and
op._spaces == opsnew[-1]._spaces):
opsnew[-1] = DiagonalOperator(opsnew[-1].diagonal() *
op.diagonal(),
domain=opsnew[-1].domain,
spaces=opsnew[-1]._spaces)
else:
opsnew.append(op)
ops = opsnew
return ops
@staticmethod
def make(ops):
ops = tuple(ops)
ops = ChainOperator.simplify(ops)
if len(ops) == 1:
return ops[0]
return ChainOperator(ops, _callingfrommake=True)
@property
def domain(self):
......
......@@ -130,5 +130,13 @@ class DiagonalOperator(EndomorphicOperator):
@property
def capability(self):
return (self.TIMES | self.ADJOINT_TIMES |
self.INVERSE_TIMES | self.ADJOINT_INVERSE_TIMES)
return self._all_ops
@property
def inverse(self):
return DiagonalOperator(1./self._diagonal, self._domain, self._spaces)
@property
def adjoint(self):
return DiagonalOperator(self._diagonal.conjugate(), self._domain,
self._spaces)
......@@ -33,6 +33,7 @@ class LinearOperator(with_metaclass(
_adjointCapability = (0, 2, 1, 3, 8, 10, 9, 11, 4, 6, 5, 7, 12, 14, 13, 15)
_addInverse = (0, 5, 10, 15, 5, 5, 15, 15, 10, 15, 10, 15, 15, 15, 15, 15)
_backwards = 6
_all_ops = 15
TIMES = 1
ADJOINT_TIMES = 2
INVERSE_TIMES = 4
......@@ -93,12 +94,12 @@ class LinearOperator(with_metaclass(
def __mul__(self, other):
from .chain_operator import ChainOperator
other = self._toOperator(other, self.domain)
return ChainOperator(self, other)
return ChainOperator.make([self, other])
def __rmul__(self, other):
from .chain_operator import ChainOperator
other = self._toOperator(other, self.target)
return ChainOperator(other, self)
return ChainOperator.make([other, self])
def __add__(self, other):
from .sum_operator import SumOperator
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment