Commit 7d140528 authored by Martin Reinecke's avatar Martin Reinecke

move special operators into contractions.py for now

parent 9a04be9d
Pipeline #30008 passed with stages
in 1 minute and 36 seconds
......@@ -11,12 +11,9 @@ from .smoothness_operator import SmoothnessOperator
from .power_distributor import PowerDistributor
from .inversion_enabler import InversionEnabler
from .sandwich_operator import SandwichOperator
from .outer_operator import OuterOperator
from .row_operator import RowOperator
__all__ = ["LinearOperator", "EndomorphicOperator", "ScalingOperator",
"DiagonalOperator", "HarmonicTransformOperator", "FFTOperator",
"FFTSmoothingOperator", "GeometryRemover",
"LaplaceOperator", "SmoothnessOperator", "PowerDistributor",
"InversionEnabler", "SandwichOperator",
"OuterOperator", "RowOperator"]
"InversionEnabler", "SandwichOperator"]
from .diagonal_operator import DiagonalOperator
def OuterOperator(field, row_operator):
return DiagonalOperator(field) * row_operator
from ..domain_tuple import DomainTuple
from ..field import Field
from .endomorphic_operator import EndomorphicOperator
class RowOperator(EndomorphicOperator):
def __init__(self, field):
super(RowOperator, self).__init__()
if not isinstance(field, Field):
raise TypeError("Field object required")
self._field = field
self._domain = DomainTuple.make(field.domain)
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return Field.full(self.target, self._field.vdot(x))
else:
return self._field * x.sum()
@property
def domain(self):
return self._domain
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
from ..extra.operator_tests import consistency_check
from ..operators import OuterOperator, RowOperator, SandwichOperator
from ..operators import EndomorphicOperator, DiagonalOperator, SandwichOperator
from ..field import Field
from .add import SymbolicAdd
from .constant import SymbolicZero
from .symbolic_tensor import SymbolicTensor
from ..domain_tuple import DomainTuple
class SymbolicChainLinOps(SymbolicTensor):
......@@ -190,6 +192,35 @@ class SymbolicQuad(SymbolicTensor):
return SymbolicSandwich(self._thing.derivative)
class RowOperator(EndomorphicOperator):
def __init__(self, field):
super(RowOperator, self).__init__()
if not isinstance(field, Field):
raise TypeError("Field object required")
self._field = field
self._domain = DomainTuple.make(field.domain)
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return Field.full(self.target, self._field.vdot(x))
else:
return self._field * x.sum()
@property
def domain(self):
return self._domain
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def OuterOperator(field, row_operator):
return DiagonalOperator(field) * row_operator
class SymbolicOuterProd(SymbolicTensor):
def __init__(self, snd, fst):
""" Computes A = fst*snd """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment