Skip to content
Snippets Groups Projects
Commit e82bcf40 authored by Lukas Platz's avatar Lukas Platz
Browse files

SwitchSpacesOperator created

parent 0c70a67f
No related branches found
No related tags found
1 merge request!443Switch spaces operator
...@@ -43,7 +43,7 @@ from .operators.outer_product_operator import OuterProduct ...@@ -43,7 +43,7 @@ from .operators.outer_product_operator import OuterProduct
from .operators.simple_linear_operators import ( from .operators.simple_linear_operators import (
VdotOperator, ConjugationOperator, Realizer, VdotOperator, ConjugationOperator, Realizer,
FieldAdapter, ducktape, GeometryRemover, NullOperator, FieldAdapter, ducktape, GeometryRemover, NullOperator,
MatrixProductOperator, PartialExtractor) MatrixProductOperator, PartialExtractor, SwitchSpacesOperator)
from .operators.value_inserter import ValueInserter from .operators.value_inserter import ValueInserter
from .operators.energy_operators import ( from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood, EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
......
...@@ -464,3 +464,41 @@ class MatrixProductOperator(EndomorphicOperator): ...@@ -464,3 +464,41 @@ class MatrixProductOperator(EndomorphicOperator):
res = np.tensordot(m, x.val, axes=(mat_axes, self._active_axes)) res = np.tensordot(m, x.val, axes=(mat_axes, self._active_axes))
res = np.moveaxis(res, move_axes, self._active_axes) res = np.moveaxis(res, move_axes, self._active_axes)
return Field(self._domain, res) return Field(self._domain, res)
class SwitchSpacesOperator(LinearOperator):
"""Operator to permutate the domain entries of fields.
Exchanges the entries `space1` and `space2` of the input's domain.
"""
def __init__(self, domain, space1, space2=0):
self._capability = self.TIMES | self.ADJOINT_TIMES
self._domain = DomainTuple.make(domain)
n_spaces = len(self._domain)
if space1 >= n_spaces or space1 < 0 \
or space2 >= n_spaces or space2 < 0:
raise ValueError("invalid space value")
tgt = list(self._domain)
tgt[space2] = self._domain[space1]
tgt[space1] = self._domain[space2]
self._target = DomainTuple.make(tgt)
dom_axes = self._domain.axes
tgt_axes = self._target.axes
self._axes_dom = dom_axes[space1] + dom_axes[space2]
self._axes_tgt = tgt_axes[space2] + tgt_axes[space1]
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
val = np.moveaxis(x.val, self._axes_dom, self._axes_tgt)
dom = self._target
else:
val = np.moveaxis(x.val, self._axes_tgt, self._axes_dom)
dom = self._domain
return Field(dom, val)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment