Commit 127e4a3e authored by Martin Reinecke's avatar Martin Reinecke

re-introduce automatic creation of ScalingOperators

parent 13263d5f
Pipeline #30663 passed with stages
in 1 minute and 22 seconds
...@@ -28,8 +28,8 @@ def _LinearizedPowerResponse(Instrument, nonlinearity, ht, Distributor, tau, ...@@ -28,8 +28,8 @@ def _LinearizedPowerResponse(Instrument, nonlinearity, ht, Distributor, tau,
power = exp(0.5*tau) power = exp(0.5*tau)
position = ht(Distributor(power)*xi) position = ht(Distributor(power)*xi)
linearization = makeOp(nonlinearity.derivative(position)) linearization = makeOp(nonlinearity.derivative(position))
return (makeOp(0.5, Instrument.target) * Instrument * linearization * ht * return (0.5 * Instrument * linearization * ht * makeOp(xi) * Distributor *
makeOp(xi) * Distributor * makeOp(power)) makeOp(power))
class NonlinearPowerEnergy(Energy): class NonlinearPowerEnergy(Energy):
...@@ -138,6 +138,5 @@ class NonlinearPowerEnergy(Energy): ...@@ -138,6 +138,5 @@ class NonlinearPowerEnergy(Energy):
self.position, xi_sample) self.position, xi_sample)
op = LinearizedResponse.adjoint*self.N.inverse*LinearizedResponse op = LinearizedResponse.adjoint*self.N.inverse*LinearizedResponse
result = op if result is None else result + op result = op if result is None else result + op
result = (result*makeOp(1./len(self.xi_sample_list), result.domain) + result = result*(1./len(self.xi_sample_list)) + self.T
self.T)
return InversionEnabler(result, self.inverter) return InversionEnabler(result, self.inverter)
...@@ -116,18 +116,51 @@ class LinearOperator(NiftyMetaBase()): ...@@ -116,18 +116,51 @@ class LinearOperator(NiftyMetaBase()):
the adjoint of this operator.""" the adjoint of this operator."""
return self._flip_modes(self.ADJOINT_BIT) return self._flip_modes(self.ADJOINT_BIT)
@staticmethod
def _toOperator(thing, dom):
from .scaling_operator import ScalingOperator
if isinstance(thing, LinearOperator):
return thing
if np.isscalar(thing):
return ScalingOperator(thing, dom)
return NotImplemented
def __mul__(self, other): def __mul__(self, other):
from .chain_operator import ChainOperator from .chain_operator import ChainOperator
if np.isscalar(other) and other == 1.:
return self
other = self._toOperator(other, self.domain)
return ChainOperator.make([self, other]) return ChainOperator.make([self, other])
def __rmul__(self, other):
from .chain_operator import ChainOperator
if np.isscalar(other) and other == 1.:
return self
other = self._toOperator(other, self.target)
return ChainOperator.make([other, self])
def __add__(self, other): def __add__(self, other):
from .sum_operator import SumOperator from .sum_operator import SumOperator
if np.isscalar(other) and other == 0.:
return self
other = self._toOperator(other, self.domain)
return SumOperator.make([self, other], [False, False]) return SumOperator.make([self, other], [False, False])
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other): def __sub__(self, other):
from .sum_operator import SumOperator from .sum_operator import SumOperator
if np.isscalar(other) and other == 0.:
return self
other = self._toOperator(other, self.domain)
return SumOperator.make([self, other], [False, True]) return SumOperator.make([self, other], [False, True])
def __rsub__(self, other):
from .sum_operator import SumOperator
other = self._toOperator(other, self.domain)
return SumOperator.make([other, self], [False, True])
@abc.abstractproperty @abc.abstractproperty
def capability(self): def capability(self):
"""int : the supported operation modes """int : the supported operation modes
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
from .scaling_operator import ScalingOperator from .scaling_operator import ScalingOperator
from .laplace_operator import LaplaceOperator from .laplace_operator import LaplaceOperator
from ..sugar import makeOp
def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None): def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None):
...@@ -52,4 +51,4 @@ def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None): ...@@ -52,4 +51,4 @@ def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None):
if strength == 0.: if strength == 0.:
return ScalingOperator(0., domain) return ScalingOperator(0., domain)
laplace = LaplaceOperator(domain, logarithmic=logarithmic, space=space) laplace = LaplaceOperator(domain, logarithmic=logarithmic, space=space)
return makeOp(strength**2, laplace.domain)*laplace.adjoint*laplace return (strength**2)*laplace.adjoint*laplace
...@@ -24,7 +24,6 @@ from .multi.multi_field import MultiField ...@@ -24,7 +24,6 @@ from .multi.multi_field import MultiField
from .multi.block_diagonal_operator import BlockDiagonalOperator from .multi.block_diagonal_operator import BlockDiagonalOperator
from .multi.multi_domain import MultiDomain from .multi.multi_domain import MultiDomain
from .operators.diagonal_operator import DiagonalOperator from .operators.diagonal_operator import DiagonalOperator
from .operators.scaling_operator import ScalingOperator
from .operators.power_distributor import PowerDistributor from .operators.power_distributor import PowerDistributor
from .domain_tuple import DomainTuple from .domain_tuple import DomainTuple
from . import dobj, utilities from . import dobj, utilities
...@@ -234,16 +233,12 @@ def makeDomain(domain): ...@@ -234,16 +233,12 @@ def makeDomain(domain):
return DomainTuple.make(domain) return DomainTuple.make(domain)
def makeOp(input, domain=None): def makeOp(input):
if isinstance(input, Field): if isinstance(input, Field):
return DiagonalOperator(input) return DiagonalOperator(input)
if isinstance(input, MultiField): if isinstance(input, MultiField):
return BlockDiagonalOperator({key: makeOp(val) return BlockDiagonalOperator({key: makeOp(val)
for key, val in input.items()}) for key, val in input.items()})
if np.isscalar(input):
if domain is None:
raise ValueError("domain needs to be set")
return ScalingOperator(input, domain)
raise NotImplementedError raise NotImplementedError
# Arithmetic functions working on Fields # Arithmetic functions working on Fields
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment