From 03f59cc3a560a16087bd10e4d4f22aea23d97480 Mon Sep 17 00:00:00 2001 From: Philipp Arras Date: Mon, 20 Aug 2018 11:50:40 +0200 Subject: [PATCH] DomainDistributor -> ContractionOperator --- nifty5/__init__.py | 2 +- nifty5/library/correlated_fields.py | 6 +-- ...distributor.py => contraction_operator.py} | 38 +++++++++---------- test/test_operators/test_adjoint.py | 4 +- 4 files changed, 24 insertions(+), 26 deletions(-) rename nifty5/operators/{domain_distributor.py => contraction_operator.py} (59%) diff --git a/nifty5/__init__.py b/nifty5/__init__.py index 52b4b375..2c235976 100644 --- a/nifty5/__init__.py +++ b/nifty5/__init__.py @@ -22,7 +22,7 @@ from .operators.operator import Operator from .operators.central_zero_padder import CentralZeroPadder from .operators.diagonal_operator import DiagonalOperator from .operators.distributors import DOFDistributor, PowerDistributor -from .operators.domain_distributor import DomainDistributor +from .operators.contraction_operator import ContractionOperator from .operators.endomorphic_operator import EndomorphicOperator from .operators.exp_transform import ExpTransform from .operators.harmonic_operators import ( diff --git a/nifty5/library/correlated_fields.py b/nifty5/library/correlated_fields.py index 78c428e7..26879516 100644 --- a/nifty5/library/correlated_fields.py +++ b/nifty5/library/correlated_fields.py @@ -21,8 +21,8 @@ from __future__ import absolute_import, division, print_function from ..compat import * from ..domain_tuple import DomainTuple from ..multi_domain import MultiDomain +from ..operators.contraction_operator import ContractionOperator from ..operators.distributors import PowerDistributor -from ..operators.domain_distributor import DomainDistributor from ..operators.harmonic_operators import HarmonicTransformOperator from ..operators.simple_linear_operators import FieldAdapter from ..sugar import exp @@ -65,8 +65,8 @@ def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial, pd_energy = PowerDistributor(pd_spatial.domain, p_space_energy, 1) pd = pd_spatial(pd_energy) - dom_distr_spatial = DomainDistributor(pd.domain, 0) - dom_distr_energy = DomainDistributor(pd.domain, 1) + dom_distr_spatial = ContractionOperator(pd.domain, 0).adjoint + dom_distr_energy = ContractionOperator(pd.domain, 1).adjoint a_spatial = dom_distr_spatial(amplitude_model_spatial) a_energy = dom_distr_energy(amplitude_model_energy) diff --git a/nifty5/operators/domain_distributor.py b/nifty5/operators/contraction_operator.py similarity index 59% rename from nifty5/operators/domain_distributor.py rename to nifty5/operators/contraction_operator.py index a01f9498..741f34c3 100644 --- a/nifty5/operators/domain_distributor.py +++ b/nifty5/operators/contraction_operator.py @@ -27,40 +27,38 @@ from ..field import Field from .linear_operator import LinearOperator -class DomainDistributor(LinearOperator): - """A linear operator which broadcasts a field to a larger domain. +class ContractionOperator(LinearOperator): + """A linear operator which sums up fields into the direction of subspaces. - This DomainDistributor broadcasts a field which is defined on a - DomainTuple to a DomainTuple which contains the former as a subset. The - entries of the field are copied such that they are constant in the - direction of the new spaces. + This ContractionOperator sums up a field with is defined on a DomainTuple + to a DomainTuple which contains the former as a subset. Parameters ---------- - target : Domain, tuple of Domain or DomainTuple + domain : Domain, tuple of Domain or DomainTuple spaces : int or tuple of int - The elements of "target" which are taken as domain. + The elements of "domain" which are taken as target. """ - def __init__(self, target, spaces): - self._target = DomainTuple.make(target) - self._spaces = utilities.parse_spaces(spaces, len(self._target)) - self._domain = [ - tgt for i, tgt in enumerate(self._target) if i in self._spaces + def __init__(self, domain, spaces): + self._domain = DomainTuple.make(domain) + self._spaces = utilities.parse_spaces(spaces, len(self._domain)) + self._target = [ + dom for i, dom in enumerate(self._domain) if i in self._spaces ] - self._domain = DomainTuple.make(self._domain) + self._target = DomainTuple.make(self._target) self._capability = self.TIMES | self.ADJOINT_TIMES def apply(self, x, mode): self._check_input(x, mode) - if mode == self.TIMES: + if mode == self.ADJOINT_TIMES: ldat = x.local_data if 0 in self._spaces else x.to_global_data() shp = [] - for i, tgt in enumerate(self._target): - tmp = tgt.shape if i > 0 else tgt.local_shape - shp += tmp if i in self._spaces else (1,)*len(tgt.shape) - ldat = np.broadcast_to(ldat.reshape(shp), self._target.local_shape) - return Field.from_local_data(self._target, ldat) + for i, dom in enumerate(self._domain): + tmp = dom.shape if i > 0 else dom.local_shape + shp += tmp if i in self._spaces else (1,)*len(dom.shape) + ldat = np.broadcast_to(ldat.reshape(shp), self._domain.local_shape) + return Field.from_local_data(self._domain, ldat) else: return x.sum( [s for s in range(len(x.domain)) if s not in self._spaces]) diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py index b6697775..278bf1b9 100644 --- a/test/test_operators/test_adjoint.py +++ b/test/test_operators/test_adjoint.py @@ -188,10 +188,10 @@ class Consistency_Tests(unittest.TestCase): @expand(product([0, 1, 2, 3, (0, 1), (0, 2), (0, 1, 2), (0, 2, 3), (1, 3)], [np.float64, np.complex128])) - def testDomainDistributor(self, spaces, dtype): + def testContractionOperator(self, spaces, dtype): dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.GLSpace(5), ift.HPSpace(4)) - op = ift.DomainDistributor(dom, spaces) + op = ift.ContractionOperator(dom, spaces) ift.extra.consistency_check(op, dtype, dtype) @expand(product([0, 2], [np.float64, np.complex128])) -- GitLab