Commit 15e11814 authored by Lukas Platz's avatar Lukas Platz
Browse files

remove SwitchSpacesOperator (superseeded by LinearEinsum)

parent 7e2c809a
Pipeline #75130 passed with stages
in 26 minutes and 4 seconds
......@@ -16,7 +16,6 @@ In addition to the below changes, the following operators were introduced:
* PartialConjugate: Conjugates parts of a multi-field
* SliceOperator: Geometry preserving mask operator
* SplitOperator: Splits a single field into a multi-field
* SwitchSpacesOperator: Permutes the domain entries of fields
FFT convention adjusted
=======================
......
......@@ -43,9 +43,8 @@ from .operators.selection_operators import SliceOperator, SplitOperator
from .operators.block_diagonal_operator import BlockDiagonalOperator
from .operators.outer_product_operator import OuterProduct
from .operators.simple_linear_operators import (
VdotOperator, ConjugationOperator, Realizer,
FieldAdapter, ducktape, GeometryRemover, NullOperator,
PartialExtractor, SwitchSpacesOperator)
VdotOperator, ConjugationOperator, Realizer, FieldAdapter, ducktape,
GeometryRemover, NullOperator, PartialExtractor)
from .operators.matrix_product_operator import MatrixProductOperator
from .operators.value_inserter import ValueInserter
from .operators.energy_operators import (
......
......@@ -349,41 +349,3 @@ class PartialExtractor(LinearOperator):
res0 = MultiField.from_dict({key: x[key] for key in x.domain.keys()})
res1 = MultiField.full(self._compldomain, 0.)
return res0.unite(res1)
class SwitchSpacesOperator(LinearOperator):
"""Operator to permutate the domain entries of fields.
Exchanges the entries `space1` and `space2` of the input's domain.
"""
def __init__(self, domain, space1, space2=0):
self._capability = self.TIMES | self.ADJOINT_TIMES
self._domain = DomainTuple.make(domain)
n_spaces = len(self._domain)
if space1 >= n_spaces or space1 < 0 \
or space2 >= n_spaces or space2 < 0:
raise ValueError("invalid space value")
tgt = list(self._domain)
tgt[space2] = self._domain[space1]
tgt[space1] = self._domain[space2]
self._target = DomainTuple.make(tgt)
dom_axes = self._domain.axes
tgt_axes = self._target.axes
self._axes_dom = dom_axes[space1] + dom_axes[space2]
self._axes_tgt = tgt_axes[space2] + tgt_axes[space1]
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
val = np.moveaxis(x.val, self._axes_dom, self._axes_tgt)
dom = self._target
else:
val = np.moveaxis(x.val, self._axes_tgt, self._axes_dom)
dom = self._domain
return Field(dom, val)
......@@ -326,20 +326,3 @@ def testSlowFieldAdapter(seed):
dom = {'a': ift.RGSpace(1), 'b': ift.RGSpace(2)}
op = ift.operators.simple_linear_operators._SlowFieldAdapter(dom, 'a')
ift.extra.consistency_check(op)
@pmp('sp1', [0, 2])
@pmp('sp2', [1])
@pmp('seed', [12, 3])
def testSwitchSpacesOperator(sp1, sp2, seed):
with ift.random.Context(seed):
dom1 = ift.RGSpace(1)
dom2 = ift.RGSpace((2, 2))
dom3 = ift.RGSpace(3)
dom = ift.DomainTuple.make([dom1, dom2, dom3])
op = ift.SwitchSpacesOperator(dom, sp1, sp2)
tgt = list(dom)
tgt[sp1] = dom[sp2]
tgt[sp2] = dom[sp1]
assert op.target == ift.DomainTuple.make(tgt)
ift.extra.consistency_check(op)
......@@ -92,6 +92,44 @@ def test_linear_einsum_contraction(space1, space2, dtype, n_invocations=10):
assert_allclose(le.adjoint(r_adj).val, le_ift.adjoint(r_adj).val)
class _SwitchSpacesOperator(ift.LinearOperator):
"""Operator to permutate the domain entries of fields.
Exchanges the entries `space1` and `space2` of the input's domain.
"""
def __init__(self, domain, space1, space2=0):
self._capability = self.TIMES | self.ADJOINT_TIMES
self._domain = ift.DomainTuple.make(domain)
n_spaces = len(self._domain)
if space1 >= n_spaces or space1 < 0 \
or space2 >= n_spaces or space2 < 0:
raise ValueError("invalid space value")
tgt = list(self._domain)
tgt[space2] = self._domain[space1]
tgt[space1] = self._domain[space2]
self._target = ift.DomainTuple.make(tgt)
dom_axes = self._domain.axes
tgt_axes = self._target.axes
self._axes_dom = dom_axes[space1] + dom_axes[space2]
self._axes_tgt = tgt_axes[space2] + tgt_axes[space1]
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
val = np.moveaxis(x.val, self._axes_dom, self._axes_tgt)
dom = self._target
else:
val = np.moveaxis(x.val, self._axes_tgt, self._axes_dom)
dom = self._domain
return ift.Field(dom, val)
def test_multi_linear_einsum_outer(
space1, space2, dtype, n_invocations=10, ntries=100
):
......@@ -116,7 +154,7 @@ def test_multi_linear_einsum_outer(
ift.full(mf_dom["dom01"], 1.), ift.DomainTuple.make(mf_dom["dom02"][1])
)
# SwitchSpacesOperator is equivalent to LinearEinsum with "ij->ji"
mle_ift = ift.SwitchSpacesOperator(
mle_ift = _SwitchSpacesOperator(
outer_i.target, 1
) @ outer_i @ ift.FieldAdapter(mf_dom["dom01"], "dom01") * ift.FieldAdapter(
mf_dom["dom02"], "dom02"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment