From 5d12ab94b20924807ebf37c9c5ba691b945a8cab Mon Sep 17 00:00:00 2001
From: Philipp Arras <c@philipp-arras.de>
Date: Wed, 5 Oct 2022 15:43:41 +0200
Subject: [PATCH] Add test

---
 test/test_energies.py | 27 ++++++++++++++++++++++++++-
 1 file changed, 26 insertions(+), 1 deletion(-)

diff --git a/test/test_energies.py b/test/test_energies.py
index 681faac3..b5812855 100644
--- a/test/test_energies.py
+++ b/test/test_energies.py
@@ -74,4 +74,29 @@ def test_varcov_gaussian_energy(dtype, with_mask):
         "logicov": rve.dtype_complex2float(dtype, force=True),
     }
     rve.operator_equality(op1.nifty_equivalent, op1, ntries=5, domain_dtype=dt, rtol=2e-5)
-    rve.operator_equality(op2.nifty_equivalent, op2, ntries=5, domain_dtype=dt, rtol=2e-5)
+
+
+def test_varcov_mask(with_mask):
+    dtype = np.float64
+    dom = ift.UnstructuredDomain([4])
+    mean = ift.from_random(dom, dtype=dtype)
+    if with_mask:
+        rng = np.random.default_rng(42)
+        mask = (rng.uniform(0, 1, mean.shape) > 0.5).astype(np.uint8)
+        mask = ift.makeField(mean.domain, mask)
+    else:
+        mask = None
+
+    logwgt = ift.from_random(dom)
+    if with_mask:
+        logwgt = logwgt.val_rw()
+        logwgt[~mask.val.astype(bool)] = np.nan
+        logwgt = ift.makeField(dom, logwgt)
+
+    logwgtop = ift.makeOp(logwgt).ducktape("wgts").ducktape_left("logicov")
+
+    lh = rve.VariableCovarianceDiagonalGaussianLikelihood(
+        mean, "signal", "logicov", mask=mask, nthreads=1
+    ).partial_insert(logwgtop)
+
+    assert not np.isnan(lh(ift.from_random(lh.domain)).val)
-- 
GitLab