diff --git a/test/test_energies.py b/test/test_energies.py index 681faac364f36b99c18c02584603c0325116b24f..b5812855f37a4c046fa7e9adc62cbceb2933bbe5 100644 --- a/test/test_energies.py +++ b/test/test_energies.py @@ -74,4 +74,29 @@ def test_varcov_gaussian_energy(dtype, with_mask): "logicov": rve.dtype_complex2float(dtype, force=True), } rve.operator_equality(op1.nifty_equivalent, op1, ntries=5, domain_dtype=dt, rtol=2e-5) - rve.operator_equality(op2.nifty_equivalent, op2, ntries=5, domain_dtype=dt, rtol=2e-5) + + +def test_varcov_mask(with_mask): + dtype = np.float64 + dom = ift.UnstructuredDomain([4]) + mean = ift.from_random(dom, dtype=dtype) + if with_mask: + rng = np.random.default_rng(42) + mask = (rng.uniform(0, 1, mean.shape) > 0.5).astype(np.uint8) + mask = ift.makeField(mean.domain, mask) + else: + mask = None + + logwgt = ift.from_random(dom) + if with_mask: + logwgt = logwgt.val_rw() + logwgt[~mask.val.astype(bool)] = np.nan + logwgt = ift.makeField(dom, logwgt) + + logwgtop = ift.makeOp(logwgt).ducktape("wgts").ducktape_left("logicov") + + lh = rve.VariableCovarianceDiagonalGaussianLikelihood( + mean, "signal", "logicov", mask=mask, nthreads=1 + ).partial_insert(logwgtop) + + assert not np.isnan(lh(ift.from_random(lh.domain)).val)