diff --git a/nifty6/library/correlated_fields.py b/nifty6/library/correlated_fields.py index cd124a14b518e6594fb68259dbda9a2119ba9a6b..e41c3eee6f4398876886982ce2b4abcdc4713225 100644 --- a/nifty6/library/correlated_fields.py +++ b/nifty6/library/correlated_fields.py @@ -38,6 +38,7 @@ from ..operators.operator import Operator from ..operators.simple_linear_operators import ducktape from ..probing import StatCalculator from ..sugar import full, makeDomain, makeField, makeOp +from .. import utilities def _reshaper(x, N): @@ -255,7 +256,7 @@ class _Distributor(LinearOperator): res = x[self._dofdex] else: res = np.zeros(self._tgt(mode).shape, dtype=x.dtype) - res[self._dofdex] = x + res = utilities.special_add_at(res, 0, self._dofdex, x) return makeField(self._tgt(mode), res) diff --git a/test/test_operators/test_correlated_fields.py b/test/test_operators/test_correlated_fields.py index cc9b73212d2a4e091c5957fcad68de063f27b4fe..26352c49f40d6013c6e974443507268862cad516 100644 --- a/test/test_operators/test_correlated_fields.py +++ b/test/test_operators/test_correlated_fields.py @@ -32,8 +32,26 @@ def _stats(op, samples): return sc.mean.val, sc.var.ptw("sqrt").val -@pmp('sspace', [ift.RGSpace(4), ift.RGSpace((4, 4), (0.123, 0.4)), - ift.HPSpace(8), ift.GLSpace(4)]) +@pmp('dofdex', [[0, 0], [0, 1]]) +@pmp('seed', [12, 3]) +def testDistributor(dofdex, seed): + with ift.random.Context(seed): + dom = ift.RGSpace(3) + N_copies = max(dofdex) + 1 + distributed_target = ift.makeDomain( + (ift.UnstructuredDomain(len(dofdex)), dom)) + target = ift.makeDomain((ift.UnstructuredDomain(N_copies), dom)) + op = ift.library.correlated_fields._Distributor( + dofdex, target, distributed_target) + ift.extra.consistency_check(op) + + +@pmp('sspace', [ + ift.RGSpace(4), + ift.RGSpace((4, 4), (0.123, 0.4)), + ift.HPSpace(8), + ift.GLSpace(4) +]) @pmp('N', [0, 2]) def testAmplitudesInvariants(sspace, N): fsspace = ift.RGSpace((12,), (0.4,)) @@ -94,7 +112,9 @@ def testAmplitudesInvariants(sspace, N): return for ampl in fa.normalized_amplitudes: - ift.extra.check_jacobian_consistency(ampl, ift.from_random(ampl.domain), + ift.extra.check_jacobian_consistency(ampl, + ift.from_random(ampl.domain), ntries=10) - ift.extra.check_jacobian_consistency(op, ift.from_random(op.domain), + ift.extra.check_jacobian_consistency(op, + ift.from_random(op.domain), ntries=10)