diff --git a/nifty5/operators/mask_operator.py b/nifty5/operators/mask_operator.py index 0ff519fbc6aed7b5d99b1830b874277bab683891..ef61ac2502f8b53e8bd411eefce6a059f2743551 100644 --- a/nifty5/operators/mask_operator.py +++ b/nifty5/operators/mask_operator.py @@ -34,18 +34,21 @@ class MaskOperator(LinearOperator): self._target = DomainTuple.make(UnstructuredDomain(self._mask.sum())) def data_indices(self): - return np.indices(self.domain.shape).transpose((1, 2, 0))[self._mask] + if len(self.domain.shape) == 1: + return np.arange(self.domain.shape[0])[self._mask] + if len(self.domain.shape) == 2: + return np.indices(self.domain.shape).transpose((1, 2, 0))[self._mask] def apply(self, x, mode): self._check_input(x, mode) if mode == self.TIMES: res = x.to_global_data()[self._mask] - return Field(self.target, res) + return Field.from_global_data(self.target, res) x = x.to_global_data() res = np.empty(self.domain.shape, x.dtype) res[self._mask] = x res[~self._mask] = 0 - return Field(self.domain, res) + return Field.from_global_data(self.domain, res) @property def capability(self): diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py index 13fd6650473665a2efe76c75adc52098f9073799..9df172961406ce7b825b375ab020083a679d5132 100644 --- a/test/test_operators/test_adjoint.py +++ b/test/test_operators/test_adjoint.py @@ -64,10 +64,10 @@ class Consistency_Tests(unittest.TestCase): @expand(product(_p_spaces, [np.float64, np.complex128])) def testMask(self, sp, dtype): # Create mask - f = ift.from_random('normal', sp).val + f = ift.from_random('normal', sp).to_global_data() mask = np.zeros_like(f) mask[f > 0] = 1 - mask = ift.Field(sp, mask) + mask = ift.Field.from_global_data(sp, mask) # Test MaskOperator op = ift.MaskOperator(mask) ift.extra.consistency_check(op, dtype, dtype)