Commit 0de313ef authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'mask_operator' into 'NIFTy_5'

Add mask operator

See merge request ift/nifty-dev!20
parents cb66a789 565e10ac
......@@ -10,6 +10,7 @@ from .harmonic_transform_operator import HarmonicTransformOperator
from .inversion_enabler import InversionEnabler
from .laplace_operator import LaplaceOperator
from .linear_operator import LinearOperator
from .mask_operator import MaskOperator
from .multi_adaptor import MultiAdaptor
from .power_distributor import PowerDistributor
from .qht_operator import QHTOperator
......@@ -23,7 +24,7 @@ from .symmetrizing_operator import SymmetrizingOperator
__all__ = ["LinearOperator", "EndomorphicOperator", "ScalingOperator",
"DiagonalOperator", "HarmonicTransformOperator", "FFTOperator",
"FFTSmoothingOperator", "GeometryRemover",
"FFTSmoothingOperator", "GeometryRemover", "MaskOperator",
"LaplaceOperator", "SmoothnessOperator", "PowerDistributor",
"InversionEnabler", "SandwichOperator", "SamplingEnabler",
"DOFDistributor", "SelectionOperator", "MultiAdaptor",
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <>.
# Copyright(C) 2013-2018 Max-Planck-Society
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from ..domain_tuple import DomainTuple
from import UnstructuredDomain
from ..field import Field
from .linear_operator import LinearOperator
class MaskOperator(LinearOperator):
def __init__(self, mask):
if not isinstance(mask, Field):
raise TypeError
self._domain = DomainTuple.make(mask.domain)
self._mask = np.logical_not(mask.to_global_data())
self._target = DomainTuple.make(UnstructuredDomain(self._mask.sum()))
def data_indices(self):
if len(self.domain.shape) == 1:
return np.arange(self.domain.shape[0])[self._mask]
if len(self.domain.shape) == 2:
return np.indices(self.domain.shape).transpose((1, 2, 0))[self._mask]
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
res = x.to_global_data()[self._mask]
return Field.from_global_data(, res)
x = x.to_global_data()
res = np.empty(self.domain.shape, x.dtype)
res[self._mask] = x
res[~self._mask] = 0
return Field.from_global_data(self.domain, res)
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def domain(self):
return self._domain
def target(self):
return self._target
......@@ -61,6 +61,17 @@ class Consistency_Tests(unittest.TestCase):
op = ift.HarmonicTransformOperator(sp)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_p_spaces, [np.float64, np.complex128]))
def testMask(self, sp, dtype):
# Create mask
f = ift.from_random('normal', sp).to_global_data()
mask = np.zeros_like(f)
mask[f > 0] = 1
mask = ift.Field.from_global_data(sp, mask)
# Test MaskOperator
op = ift.MaskOperator(mask)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces+_p_spaces, [np.float64, np.complex128]))
def testDiagonal(self, sp, dtype):
op = ift.DiagonalOperator(ift.Field.from_random("normal", sp,
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment