Commit 13c8fe93 authored by Philipp Arras's avatar Philipp Arras
Browse files

Implement MaskOperator

parent 4a29e1ec
......@@ -16,33 +16,35 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..sugar import full
from .linear_operator import LinearOperator
class MaskOperator(LinearOperator):
def __init__(self, domain, target, xy):
self._domain = DomainTuple.make(domain)
# TODO Takes a field (boolean or 0/1)
# TODO Add MultiFields (output MultiField of unstructured domains)
def __init__(self, mask):
if not isinstance(mask, Field):
raise TypeError
assert len(xy.shape) == 2
assert xy.shape[1] == 2
self._target = UnstructuredDomain(xy.shape[0])
self._domain = DomainTuple.make(mask.domain)
self._mask = np.logical_not(mask.to_global_data())
self._target = DomainTuple.make(UnstructuredDomain(self._mask.sum()))
self._xs = xy.T[0]
self._ys = xy.T[1]
def data_indices(self):
return np.indices(self.domain.shape).transpose((1, 2, 0))[self._mask]
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
res = x.val[self._xs, self._ys]
res = x.to_global_data()[self._mask]
return Field(self.target, res)
res = full(self.domain, 0.)
res[self._xs, self._ys] = x.val
x = x.to_global_data()
res = np.empty(self.domain.shape, x.dtype)
res[self._mask] = x
res[~self._mask] = 0
return Field(self.domain, res)
@property
......
......@@ -61,6 +61,17 @@ class Consistency_Tests(unittest.TestCase):
op = ift.HarmonicTransformOperator(sp)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_p_spaces, [np.float64, np.complex128]))
def testMask(self, sp, dtype):
# Create mask
f = ift.from_random('normal', sp).val
mask = np.zeros_like(f)
mask[f > 0] = 1
mask = ift.Field(sp, mask)
# Test MaskOperator
op = ift.MaskOperator(mask)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces+_p_spaces, [np.float64, np.complex128]))
def testDiagonal(self, sp, dtype):
op = ift.DiagonalOperator(ift.Field.from_random("normal", sp,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment