diff --git a/nifty6/library/correlated_fields.py b/nifty6/library/correlated_fields.py index 5e50c8c4b7e719070f5233d936b8442b55c7938c..824981efae8cb6e2019a6b0158bdf912c3608fdd 100644 --- a/nifty6/library/correlated_fields.py +++ b/nifty6/library/correlated_fields.py @@ -11,7 +11,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . # -# Copyright(C) 2013-2019 Max-Planck-Society +# Copyright(C) 2013-2020 Max-Planck-Society # Authors: Philipp Frank, Philipp Arras, Philipp Haim # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. @@ -25,7 +25,6 @@ from ..domain_tuple import DomainTuple from ..domains.power_space import PowerSpace from ..domains.unstructured_domain import UnstructuredDomain from ..field import Field -from ..linearization import Linearization from ..logger import logger from ..multi_field import MultiField from ..operators.adder import Adder @@ -244,10 +243,9 @@ class _SpecialSum(EndomorphicOperator): class _Distributor(LinearOperator): def __init__(self, dofdex, domain, target): - self._dofdex = dofdex - - self._target = makeDomain(target) - self._domain = makeDomain(domain) + self._dofdex = np.array(dofdex) + self._target = DomainTuple.make(target) + self._domain = DomainTuple.make(domain) self._capability = self.TIMES | self.ADJOINT_TIMES def apply(self, x, mode): @@ -256,7 +254,7 @@ class _Distributor(LinearOperator): if mode == self.TIMES: res = x[self._dofdex] else: - res = np.empty(self._tgt(mode).shape) + res = np.zeros(self._tgt(mode).shape, dtype=x.dtype) res[self._dofdex] = x return makeField(self._tgt(mode), res) diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py index 459952de6cdb770f9846301d70bb02a810825edb..cd600be0dd10e45dac0576c965644367f67e842e 100644 --- a/test/test_operators/test_adjoint.py +++ b/test/test_operators/test_adjoint.py @@ -326,3 +326,11 @@ def testSlowFieldAdapter(seed): dom = {'a': ift.RGSpace(1), 'b': ift.RGSpace(2)} op = ift.operators.simple_linear_operators._SlowFieldAdapter(dom, 'a') ift.extra.consistency_check(op) + + +@pmp('dofdex', [(0,), (1,), (0, 1), (1, 0)]) +def testCorFldDistr(dofdex): + tgt = ift.UnstructuredDomain(len(dofdex)) + dom = ift.UnstructuredDomain(2) + op = ift.library.correlated_fields._Distributor(dofdex, dom, tgt) + ift.extra.consistency_check(op)