Commit 50ffd58d authored by Lukas Platz's avatar Lukas Platz

CorrelatedField: fix for `total_N > 1`

parent 8ff10b70
Pipeline #75337 canceled with stages
in 6 minutes and 57 seconds
...@@ -38,6 +38,7 @@ from ..operators.operator import Operator ...@@ -38,6 +38,7 @@ from ..operators.operator import Operator
from ..operators.simple_linear_operators import ducktape from ..operators.simple_linear_operators import ducktape
from ..probing import StatCalculator from ..probing import StatCalculator
from ..sugar import full, makeDomain, makeField, makeOp from ..sugar import full, makeDomain, makeField, makeOp
from .. import utilities
def _reshaper(x, N): def _reshaper(x, N):
...@@ -255,7 +256,7 @@ class _Distributor(LinearOperator): ...@@ -255,7 +256,7 @@ class _Distributor(LinearOperator):
res = x[self._dofdex] res = x[self._dofdex]
else: else:
res = np.zeros(self._tgt(mode).shape, dtype=x.dtype) res = np.zeros(self._tgt(mode).shape, dtype=x.dtype)
res[self._dofdex] = x res = utilities.special_add_at(res, 0, self._dofdex, x)
return makeField(self._tgt(mode), res) return makeField(self._tgt(mode), res)
......
...@@ -32,8 +32,26 @@ def _stats(op, samples): ...@@ -32,8 +32,26 @@ def _stats(op, samples):
return sc.mean.val, sc.var.ptw("sqrt").val return sc.mean.val, sc.var.ptw("sqrt").val
@pmp('sspace', [ift.RGSpace(4), ift.RGSpace((4, 4), (0.123, 0.4)), @pmp('dofdex', [[0, 0], [0, 1]])
ift.HPSpace(8), ift.GLSpace(4)]) @pmp('seed', [12, 3])
def testDistributor(dofdex, seed):
with ift.random.Context(seed):
dom = ift.RGSpace(3)
N_copies = max(dofdex) + 1
distributed_target = ift.makeDomain(
(ift.UnstructuredDomain(len(dofdex)), dom))
target = ift.makeDomain((ift.UnstructuredDomain(N_copies), dom))
op = ift.library.correlated_fields._Distributor(
dofdex, target, distributed_target)
ift.extra.consistency_check(op)
@pmp('sspace', [
ift.RGSpace(4),
ift.RGSpace((4, 4), (0.123, 0.4)),
ift.HPSpace(8),
ift.GLSpace(4)
])
@pmp('N', [0, 2]) @pmp('N', [0, 2])
def testAmplitudesInvariants(sspace, N): def testAmplitudesInvariants(sspace, N):
fsspace = ift.RGSpace((12,), (0.4,)) fsspace = ift.RGSpace((12,), (0.4,))
...@@ -94,7 +112,9 @@ def testAmplitudesInvariants(sspace, N): ...@@ -94,7 +112,9 @@ def testAmplitudesInvariants(sspace, N):
return return
for ampl in fa.normalized_amplitudes: for ampl in fa.normalized_amplitudes:
ift.extra.check_jacobian_consistency(ampl, ift.from_random(ampl.domain), ift.extra.check_jacobian_consistency(ampl,
ift.from_random(ampl.domain),
ntries=10) ntries=10)
ift.extra.check_jacobian_consistency(op, ift.from_random(op.domain), ift.extra.check_jacobian_consistency(op,
ift.from_random(op.domain),
ntries=10) ntries=10)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment