Commit 03f59cc3 authored by Philipp Arras's avatar Philipp Arras

DomainDistributor -> ContractionOperator

parent 63789fe6
...@@ -22,7 +22,7 @@ from .operators.operator import Operator ...@@ -22,7 +22,7 @@ from .operators.operator import Operator
from .operators.central_zero_padder import CentralZeroPadder from .operators.central_zero_padder import CentralZeroPadder
from .operators.diagonal_operator import DiagonalOperator from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import DOFDistributor, PowerDistributor from .operators.distributors import DOFDistributor, PowerDistributor
from .operators.domain_distributor import DomainDistributor from .operators.contraction_operator import ContractionOperator
from .operators.endomorphic_operator import EndomorphicOperator from .operators.endomorphic_operator import EndomorphicOperator
from .operators.exp_transform import ExpTransform from .operators.exp_transform import ExpTransform
from .operators.harmonic_operators import ( from .operators.harmonic_operators import (
......
...@@ -21,8 +21,8 @@ from __future__ import absolute_import, division, print_function ...@@ -21,8 +21,8 @@ from __future__ import absolute_import, division, print_function
from ..compat import * from ..compat import *
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..multi_domain import MultiDomain from ..multi_domain import MultiDomain
from ..operators.contraction_operator import ContractionOperator
from ..operators.distributors import PowerDistributor from ..operators.distributors import PowerDistributor
from ..operators.domain_distributor import DomainDistributor
from ..operators.harmonic_operators import HarmonicTransformOperator from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.simple_linear_operators import FieldAdapter from ..operators.simple_linear_operators import FieldAdapter
from ..sugar import exp from ..sugar import exp
...@@ -65,8 +65,8 @@ def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial, ...@@ -65,8 +65,8 @@ def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial,
pd_energy = PowerDistributor(pd_spatial.domain, p_space_energy, 1) pd_energy = PowerDistributor(pd_spatial.domain, p_space_energy, 1)
pd = pd_spatial(pd_energy) pd = pd_spatial(pd_energy)
dom_distr_spatial = DomainDistributor(pd.domain, 0) dom_distr_spatial = ContractionOperator(pd.domain, 0).adjoint
dom_distr_energy = DomainDistributor(pd.domain, 1) dom_distr_energy = ContractionOperator(pd.domain, 1).adjoint
a_spatial = dom_distr_spatial(amplitude_model_spatial) a_spatial = dom_distr_spatial(amplitude_model_spatial)
a_energy = dom_distr_energy(amplitude_model_energy) a_energy = dom_distr_energy(amplitude_model_energy)
......
...@@ -27,40 +27,38 @@ from ..field import Field ...@@ -27,40 +27,38 @@ from ..field import Field
from .linear_operator import LinearOperator from .linear_operator import LinearOperator
class DomainDistributor(LinearOperator): class ContractionOperator(LinearOperator):
"""A linear operator which broadcasts a field to a larger domain. """A linear operator which sums up fields into the direction of subspaces.
This DomainDistributor broadcasts a field which is defined on a This ContractionOperator sums up a field with is defined on a DomainTuple
DomainTuple to a DomainTuple which contains the former as a subset. The to a DomainTuple which contains the former as a subset.
entries of the field are copied such that they are constant in the
direction of the new spaces.
Parameters Parameters
---------- ----------
target : Domain, tuple of Domain or DomainTuple domain : Domain, tuple of Domain or DomainTuple
spaces : int or tuple of int spaces : int or tuple of int
The elements of "target" which are taken as domain. The elements of "domain" which are taken as target.
""" """
def __init__(self, target, spaces): def __init__(self, domain, spaces):
self._target = DomainTuple.make(target) self._domain = DomainTuple.make(domain)
self._spaces = utilities.parse_spaces(spaces, len(self._target)) self._spaces = utilities.parse_spaces(spaces, len(self._domain))
self._domain = [ self._target = [
tgt for i, tgt in enumerate(self._target) if i in self._spaces dom for i, dom in enumerate(self._domain) if i in self._spaces
] ]
self._domain = DomainTuple.make(self._domain) self._target = DomainTuple.make(self._target)
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
if mode == self.TIMES: if mode == self.ADJOINT_TIMES:
ldat = x.local_data if 0 in self._spaces else x.to_global_data() ldat = x.local_data if 0 in self._spaces else x.to_global_data()
shp = [] shp = []
for i, tgt in enumerate(self._target): for i, dom in enumerate(self._domain):
tmp = tgt.shape if i > 0 else tgt.local_shape tmp = dom.shape if i > 0 else dom.local_shape
shp += tmp if i in self._spaces else (1,)*len(tgt.shape) shp += tmp if i in self._spaces else (1,)*len(dom.shape)
ldat = np.broadcast_to(ldat.reshape(shp), self._target.local_shape) ldat = np.broadcast_to(ldat.reshape(shp), self._domain.local_shape)
return Field.from_local_data(self._target, ldat) return Field.from_local_data(self._domain, ldat)
else: else:
return x.sum( return x.sum(
[s for s in range(len(x.domain)) if s not in self._spaces]) [s for s in range(len(x.domain)) if s not in self._spaces])
...@@ -188,10 +188,10 @@ class Consistency_Tests(unittest.TestCase): ...@@ -188,10 +188,10 @@ class Consistency_Tests(unittest.TestCase):
@expand(product([0, 1, 2, 3, (0, 1), (0, 2), (0, 1, 2), (0, 2, 3), (1, 3)], @expand(product([0, 1, 2, 3, (0, 1), (0, 2), (0, 1, 2), (0, 2, 3), (1, 3)],
[np.float64, np.complex128])) [np.float64, np.complex128]))
def testDomainDistributor(self, spaces, dtype): def testContractionOperator(self, spaces, dtype):
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.GLSpace(5), dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.GLSpace(5),
ift.HPSpace(4)) ift.HPSpace(4))
op = ift.DomainDistributor(dom, spaces) op = ift.ContractionOperator(dom, spaces)
ift.extra.consistency_check(op, dtype, dtype) ift.extra.consistency_check(op, dtype, dtype)
@expand(product([0, 2], [np.float64, np.complex128])) @expand(product([0, 2], [np.float64, np.complex128]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment