Skip to content
Snippets Groups Projects
Commit 03f59cc3 authored by Philipp Arras's avatar Philipp Arras
Browse files

DomainDistributor -> ContractionOperator

parent 63789fe6
No related branches found
No related tags found
No related merge requests found
......@@ -22,7 +22,7 @@ from .operators.operator import Operator
from .operators.central_zero_padder import CentralZeroPadder
from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import DOFDistributor, PowerDistributor
from .operators.domain_distributor import DomainDistributor
from .operators.contraction_operator import ContractionOperator
from .operators.endomorphic_operator import EndomorphicOperator
from .operators.exp_transform import ExpTransform
from .operators.harmonic_operators import (
......
......@@ -21,8 +21,8 @@ from __future__ import absolute_import, division, print_function
from ..compat import *
from ..domain_tuple import DomainTuple
from ..multi_domain import MultiDomain
from ..operators.contraction_operator import ContractionOperator
from ..operators.distributors import PowerDistributor
from ..operators.domain_distributor import DomainDistributor
from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.simple_linear_operators import FieldAdapter
from ..sugar import exp
......@@ -65,8 +65,8 @@ def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial,
pd_energy = PowerDistributor(pd_spatial.domain, p_space_energy, 1)
pd = pd_spatial(pd_energy)
dom_distr_spatial = DomainDistributor(pd.domain, 0)
dom_distr_energy = DomainDistributor(pd.domain, 1)
dom_distr_spatial = ContractionOperator(pd.domain, 0).adjoint
dom_distr_energy = ContractionOperator(pd.domain, 1).adjoint
a_spatial = dom_distr_spatial(amplitude_model_spatial)
a_energy = dom_distr_energy(amplitude_model_energy)
......
......@@ -27,40 +27,38 @@ from ..field import Field
from .linear_operator import LinearOperator
class DomainDistributor(LinearOperator):
"""A linear operator which broadcasts a field to a larger domain.
class ContractionOperator(LinearOperator):
"""A linear operator which sums up fields into the direction of subspaces.
This DomainDistributor broadcasts a field which is defined on a
DomainTuple to a DomainTuple which contains the former as a subset. The
entries of the field are copied such that they are constant in the
direction of the new spaces.
This ContractionOperator sums up a field with is defined on a DomainTuple
to a DomainTuple which contains the former as a subset.
Parameters
----------
target : Domain, tuple of Domain or DomainTuple
domain : Domain, tuple of Domain or DomainTuple
spaces : int or tuple of int
The elements of "target" which are taken as domain.
The elements of "domain" which are taken as target.
"""
def __init__(self, target, spaces):
self._target = DomainTuple.make(target)
self._spaces = utilities.parse_spaces(spaces, len(self._target))
self._domain = [
tgt for i, tgt in enumerate(self._target) if i in self._spaces
def __init__(self, domain, spaces):
self._domain = DomainTuple.make(domain)
self._spaces = utilities.parse_spaces(spaces, len(self._domain))
self._target = [
dom for i, dom in enumerate(self._domain) if i in self._spaces
]
self._domain = DomainTuple.make(self._domain)
self._target = DomainTuple.make(self._target)
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
if mode == self.ADJOINT_TIMES:
ldat = x.local_data if 0 in self._spaces else x.to_global_data()
shp = []
for i, tgt in enumerate(self._target):
tmp = tgt.shape if i > 0 else tgt.local_shape
shp += tmp if i in self._spaces else (1,)*len(tgt.shape)
ldat = np.broadcast_to(ldat.reshape(shp), self._target.local_shape)
return Field.from_local_data(self._target, ldat)
for i, dom in enumerate(self._domain):
tmp = dom.shape if i > 0 else dom.local_shape
shp += tmp if i in self._spaces else (1,)*len(dom.shape)
ldat = np.broadcast_to(ldat.reshape(shp), self._domain.local_shape)
return Field.from_local_data(self._domain, ldat)
else:
return x.sum(
[s for s in range(len(x.domain)) if s not in self._spaces])
......@@ -188,10 +188,10 @@ class Consistency_Tests(unittest.TestCase):
@expand(product([0, 1, 2, 3, (0, 1), (0, 2), (0, 1, 2), (0, 2, 3), (1, 3)],
[np.float64, np.complex128]))
def testDomainDistributor(self, spaces, dtype):
def testContractionOperator(self, spaces, dtype):
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.GLSpace(5),
ift.HPSpace(4))
op = ift.DomainDistributor(dom, spaces)
op = ift.ContractionOperator(dom, spaces)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([0, 2], [np.float64, np.complex128]))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment