From 5d12ab94b20924807ebf37c9c5ba691b945a8cab Mon Sep 17 00:00:00 2001 From: Philipp Arras <c@philipp-arras.de> Date: Wed, 5 Oct 2022 15:43:41 +0200 Subject: [PATCH] Add test --- test/test_energies.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/test/test_energies.py b/test/test_energies.py index 681faac3..b5812855 100644 --- a/test/test_energies.py +++ b/test/test_energies.py @@ -74,4 +74,29 @@ def test_varcov_gaussian_energy(dtype, with_mask): "logicov": rve.dtype_complex2float(dtype, force=True), } rve.operator_equality(op1.nifty_equivalent, op1, ntries=5, domain_dtype=dt, rtol=2e-5) - rve.operator_equality(op2.nifty_equivalent, op2, ntries=5, domain_dtype=dt, rtol=2e-5) + + +def test_varcov_mask(with_mask): + dtype = np.float64 + dom = ift.UnstructuredDomain([4]) + mean = ift.from_random(dom, dtype=dtype) + if with_mask: + rng = np.random.default_rng(42) + mask = (rng.uniform(0, 1, mean.shape) > 0.5).astype(np.uint8) + mask = ift.makeField(mean.domain, mask) + else: + mask = None + + logwgt = ift.from_random(dom) + if with_mask: + logwgt = logwgt.val_rw() + logwgt[~mask.val.astype(bool)] = np.nan + logwgt = ift.makeField(dom, logwgt) + + logwgtop = ift.makeOp(logwgt).ducktape("wgts").ducktape_left("logicov") + + lh = rve.VariableCovarianceDiagonalGaussianLikelihood( + mean, "signal", "logicov", mask=mask, nthreads=1 + ).partial_insert(logwgtop) + + assert not np.isnan(lh(ift.from_random(lh.domain)).val) -- GitLab