diff --git a/nifty5/operators/mask_operator.py b/nifty5/operators/mask_operator.py index 189c87cbbd0d262e1a7ada1414501d5920679293..0ff519fbc6aed7b5d99b1830b874277bab683891 100644 --- a/nifty5/operators/mask_operator.py +++ b/nifty5/operators/mask_operator.py @@ -16,33 +16,35 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. +import numpy as np + from ..domain_tuple import DomainTuple from ..domains.unstructured_domain import UnstructuredDomain from ..field import Field -from ..sugar import full from .linear_operator import LinearOperator class MaskOperator(LinearOperator): - def __init__(self, domain, target, xy): - self._domain = DomainTuple.make(domain) - # TODO Takes a field (boolean or 0/1) - # TODO Add MultiFields (output MultiField of unstructured domains) + def __init__(self, mask): + if not isinstance(mask, Field): + raise TypeError - assert len(xy.shape) == 2 - assert xy.shape[1] == 2 - self._target = UnstructuredDomain(xy.shape[0]) + self._domain = DomainTuple.make(mask.domain) + self._mask = np.logical_not(mask.to_global_data()) + self._target = DomainTuple.make(UnstructuredDomain(self._mask.sum())) - self._xs = xy.T[0] - self._ys = xy.T[1] + def data_indices(self): + return np.indices(self.domain.shape).transpose((1, 2, 0))[self._mask] def apply(self, x, mode): self._check_input(x, mode) if mode == self.TIMES: - res = x.val[self._xs, self._ys] + res = x.to_global_data()[self._mask] return Field(self.target, res) - res = full(self.domain, 0.) - res[self._xs, self._ys] = x.val + x = x.to_global_data() + res = np.empty(self.domain.shape, x.dtype) + res[self._mask] = x + res[~self._mask] = 0 return Field(self.domain, res) @property diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py index 116cb3419f33bc6711456f703f0bf73c047e4eee..13fd6650473665a2efe76c75adc52098f9073799 100644 --- a/test/test_operators/test_adjoint.py +++ b/test/test_operators/test_adjoint.py @@ -61,6 +61,17 @@ class Consistency_Tests(unittest.TestCase): op = ift.HarmonicTransformOperator(sp) ift.extra.consistency_check(op, dtype, dtype) + @expand(product(_p_spaces, [np.float64, np.complex128])) + def testMask(self, sp, dtype): + # Create mask + f = ift.from_random('normal', sp).val + mask = np.zeros_like(f) + mask[f > 0] = 1 + mask = ift.Field(sp, mask) + # Test MaskOperator + op = ift.MaskOperator(mask) + ift.extra.consistency_check(op, dtype, dtype) + @expand(product(_h_spaces+_p_spaces, [np.float64, np.complex128])) def testDiagonal(self, sp, dtype): op = ift.DiagonalOperator(ift.Field.from_random("normal", sp,