Commit e0d1dc5f authored by Martin Reinecke's avatar Martin Reinecke
Browse files

allow chaninig and addin/subtracting between operators, fields and scalars

parent ff2eaec7
Pipeline #23254 failed with stage
in 4 minutes and 6 seconds
......@@ -452,8 +452,11 @@ class Field(object):
tval = getattr(self.val, op)(other.val)
return self if tval is self.val else Field(self.domain, tval)
tval = getattr(self.val, op)(other)
return self if tval is self.val else Field(self.domain, tval)
if np.isscalar(other) or isinstance(other, dobj.data_object):
tval = getattr(self.val, op)(other)
return self if tval is self.val else Field(self.domain, tval)
raise NotImplementedError
def __add__(self, other):
return self._binary_helper(other, op='__add__')
......
......@@ -5,84 +5,46 @@ from ..operators.linear_operator import LinearOperator
class LinearizedSignalResponse(LinearOperator):
def __init__(self, Instrument, nonlinearity, FFT, power, m):
super(LinearizedSignalResponse, self).__init__()
self.Instrument = Instrument
self.FFT = FFT
self.power = power
position = FFT.adjoint_times(self.power*m)
self.linearization = nonlinearity.derivative(position)
def _times(self, x):
tmp = self.FFT.adjoint_times(self.power*x)
tmp *= self.linearization
return self.Instrument(tmp)
def _adjoint_times(self, x):
tmp = self.Instrument.adjoint_times(x)
tmp *= self.linearization
tmp = self.FFT(tmp)
tmp *= self.power
return tmp
position = FFT.adjoint_times(power*m)
self._op = (Instrument * nonlinearity.derivative(position) *
FFT.adjoint * power)
@property
def domain(self):
return self.FFT.target
return self._op.domain
@property
def target(self):
return self.Instrument.target
return self._op.target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
return self._op.capability
def apply(self, x, mode):
self._check_input(x, mode)
return self._times(x) if mode & self.TIMES else self._adjoint_times(x)
return self._op.apply(x, mode)
class LinearizedPowerResponse(LinearOperator):
def __init__(self, Instrument, nonlinearity, FFT, Projection, t, m):
super(LinearizedPowerResponse, self).__init__()
self.Instrument = Instrument
self.FFT = FFT
self.Projection = Projection
self.power = exp(0.5*t)
self.m = m
position = FFT.adjoint_times(
self.Projection.adjoint_times(self.power) * self.m)
self.linearization = nonlinearity.derivative(position)
def _times(self, x):
tmp = self.Projection.adjoint_times(self.power*x)
tmp *= self.m
tmp = self.FFT.adjoint_times(tmp)
tmp *= self.linearization
tmp = self.Instrument(tmp)
tmp *= 0.5
return tmp
def _adjoint_times(self, x):
tmp = self.Instrument.adjoint_times(x)
tmp *= self.linearization
tmp = self.FFT(tmp)
tmp *= self.m.conjugate()
tmp = self.Projection(tmp)
tmp *= self.power
tmp *= 0.5
return tmp
power = exp(0.5*t)
position = FFT.adjoint_times(Projection.adjoint_times(power) * m)
linearization = nonlinearity.derivative(position)
self._op = (0.5 * Instrument * linearization * FFT.adjoint * m *
Projection.adjoint * power)
@property
def domain(self):
return self.power.domain
return self._op.domain
@property
def target(self):
return self.Instrument.target
return self._op.target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
return self._op.capability
def apply(self, x, mode):
self._check_input(x, mode)
return self._times(x) if mode & self.TIMES else self._adjoint_times(x)
return self._op.apply(x, mode)
......@@ -20,6 +20,7 @@ import abc
from ..utilities import NiftyMeta
from ..field import Field
from future.utils import with_metaclass
import numpy as np
class LinearOperator(with_metaclass(
......@@ -91,18 +92,46 @@ class LinearOperator(with_metaclass(
from .adjoint_operator import AdjointOperator
return AdjointOperator(self)
@staticmethod
def _toOperator(thing, dom):
from .diagonal_operator import DiagonalOperator
from .scaling_operator import ScalingOperator
if isinstance(thing, LinearOperator):
return thing
if isinstance(thing, Field):
return DiagonalOperator(thing)
if np.isscalar(thing):
return ScalingOperator(thing, dom)
raise NotImplementedError
def __mul__(self, other):
from .chain_operator import ChainOperator
other = self._toOperator(other, self.domain)
return ChainOperator(self, other)
def __rmul__(self, other):
from .chain_operator import ChainOperator
other = self._toOperator(other, self.target)
return ChainOperator(other, self)
def __add__(self, other):
from .sum_operator import SumOperator
other = self._toOperator(other, self.domain)
return SumOperator(self, other)
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
from .sum_operator import SumOperator
other = self._toOperator(other, self.domain)
return SumOperator(self, other, neg=True)
def __rsub__(self, other):
from .sum_operator import SumOperator
other = self._toOperator(other, self.domain)
return SumOperator(other, self, neg=True)
def supports(self, ops):
return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment