Commit 7d140528 authored by Martin Reinecke's avatar Martin Reinecke

move special operators into contractions.py for now

parent 9a04be9d
Pipeline #30008 passed with stages
in 1 minute and 36 seconds
...@@ -11,12 +11,9 @@ from .smoothness_operator import SmoothnessOperator ...@@ -11,12 +11,9 @@ from .smoothness_operator import SmoothnessOperator
from .power_distributor import PowerDistributor from .power_distributor import PowerDistributor
from .inversion_enabler import InversionEnabler from .inversion_enabler import InversionEnabler
from .sandwich_operator import SandwichOperator from .sandwich_operator import SandwichOperator
from .outer_operator import OuterOperator
from .row_operator import RowOperator
__all__ = ["LinearOperator", "EndomorphicOperator", "ScalingOperator", __all__ = ["LinearOperator", "EndomorphicOperator", "ScalingOperator",
"DiagonalOperator", "HarmonicTransformOperator", "FFTOperator", "DiagonalOperator", "HarmonicTransformOperator", "FFTOperator",
"FFTSmoothingOperator", "GeometryRemover", "FFTSmoothingOperator", "GeometryRemover",
"LaplaceOperator", "SmoothnessOperator", "PowerDistributor", "LaplaceOperator", "SmoothnessOperator", "PowerDistributor",
"InversionEnabler", "SandwichOperator", "InversionEnabler", "SandwichOperator"]
"OuterOperator", "RowOperator"]
from .diagonal_operator import DiagonalOperator
def OuterOperator(field, row_operator):
return DiagonalOperator(field) * row_operator
from ..domain_tuple import DomainTuple
from ..field import Field
from .endomorphic_operator import EndomorphicOperator
class RowOperator(EndomorphicOperator):
def __init__(self, field):
super(RowOperator, self).__init__()
if not isinstance(field, Field):
raise TypeError("Field object required")
self._field = field
self._domain = DomainTuple.make(field.domain)
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return Field.full(self.target, self._field.vdot(x))
else:
return self._field * x.sum()
@property
def domain(self):
return self._domain
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
from ..extra.operator_tests import consistency_check from ..extra.operator_tests import consistency_check
from ..operators import OuterOperator, RowOperator, SandwichOperator from ..operators import EndomorphicOperator, DiagonalOperator, SandwichOperator
from ..field import Field
from .add import SymbolicAdd from .add import SymbolicAdd
from .constant import SymbolicZero from .constant import SymbolicZero
from .symbolic_tensor import SymbolicTensor from .symbolic_tensor import SymbolicTensor
from ..domain_tuple import DomainTuple
class SymbolicChainLinOps(SymbolicTensor): class SymbolicChainLinOps(SymbolicTensor):
...@@ -190,6 +192,35 @@ class SymbolicQuad(SymbolicTensor): ...@@ -190,6 +192,35 @@ class SymbolicQuad(SymbolicTensor):
return SymbolicSandwich(self._thing.derivative) return SymbolicSandwich(self._thing.derivative)
class RowOperator(EndomorphicOperator):
def __init__(self, field):
super(RowOperator, self).__init__()
if not isinstance(field, Field):
raise TypeError("Field object required")
self._field = field
self._domain = DomainTuple.make(field.domain)
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return Field.full(self.target, self._field.vdot(x))
else:
return self._field * x.sum()
@property
def domain(self):
return self._domain
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def OuterOperator(field, row_operator):
return DiagonalOperator(field) * row_operator
class SymbolicOuterProd(SymbolicTensor): class SymbolicOuterProd(SymbolicTensor):
def __init__(self, snd, fst): def __init__(self, snd, fst):
""" Computes A = fst*snd """ """ Computes A = fst*snd """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment