Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
0de313ef
Commit
0de313ef
authored
Jun 29, 2018
by
Martin Reinecke
Browse files
Merge branch 'mask_operator' into 'NIFTy_5'
Add mask operator See merge request ift/nifty-dev!20
parents
cb66a789
565e10ac
Changes
3
Hide whitespace changes
Inline
Side-by-side
nifty5/operators/__init__.py
View file @
0de313ef
...
...
@@ -10,6 +10,7 @@ from .harmonic_transform_operator import HarmonicTransformOperator
from
.inversion_enabler
import
InversionEnabler
from
.laplace_operator
import
LaplaceOperator
from
.linear_operator
import
LinearOperator
from
.mask_operator
import
MaskOperator
from
.multi_adaptor
import
MultiAdaptor
from
.power_distributor
import
PowerDistributor
from
.qht_operator
import
QHTOperator
...
...
@@ -23,7 +24,7 @@ from .symmetrizing_operator import SymmetrizingOperator
__all__
=
[
"LinearOperator"
,
"EndomorphicOperator"
,
"ScalingOperator"
,
"DiagonalOperator"
,
"HarmonicTransformOperator"
,
"FFTOperator"
,
"FFTSmoothingOperator"
,
"GeometryRemover"
,
"FFTSmoothingOperator"
,
"GeometryRemover"
,
"MaskOperator"
,
"LaplaceOperator"
,
"SmoothnessOperator"
,
"PowerDistributor"
,
"InversionEnabler"
,
"SandwichOperator"
,
"SamplingEnabler"
,
"DOFDistributor"
,
"SelectionOperator"
,
"MultiAdaptor"
,
...
...
nifty5/operators/mask_operator.py
0 → 100644
View file @
0de313ef
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import
numpy
as
np
from
..domain_tuple
import
DomainTuple
from
..domains.unstructured_domain
import
UnstructuredDomain
from
..field
import
Field
from
.linear_operator
import
LinearOperator
class
MaskOperator
(
LinearOperator
):
def
__init__
(
self
,
mask
):
if
not
isinstance
(
mask
,
Field
):
raise
TypeError
self
.
_domain
=
DomainTuple
.
make
(
mask
.
domain
)
self
.
_mask
=
np
.
logical_not
(
mask
.
to_global_data
())
self
.
_target
=
DomainTuple
.
make
(
UnstructuredDomain
(
self
.
_mask
.
sum
()))
def
data_indices
(
self
):
if
len
(
self
.
domain
.
shape
)
==
1
:
return
np
.
arange
(
self
.
domain
.
shape
[
0
])[
self
.
_mask
]
if
len
(
self
.
domain
.
shape
)
==
2
:
return
np
.
indices
(
self
.
domain
.
shape
).
transpose
((
1
,
2
,
0
))[
self
.
_mask
]
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
res
=
x
.
to_global_data
()[
self
.
_mask
]
return
Field
.
from_global_data
(
self
.
target
,
res
)
x
=
x
.
to_global_data
()
res
=
np
.
empty
(
self
.
domain
.
shape
,
x
.
dtype
)
res
[
self
.
_mask
]
=
x
res
[
~
self
.
_mask
]
=
0
return
Field
.
from_global_data
(
self
.
domain
,
res
)
@
property
def
capability
(
self
):
return
self
.
TIMES
|
self
.
ADJOINT_TIMES
@
property
def
domain
(
self
):
return
self
.
_domain
@
property
def
target
(
self
):
return
self
.
_target
test/test_operators/test_adjoint.py
View file @
0de313ef
...
...
@@ -61,6 +61,17 @@ class Consistency_Tests(unittest.TestCase):
op
=
ift
.
HarmonicTransformOperator
(
sp
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
(
_p_spaces
,
[
np
.
float64
,
np
.
complex128
]))
def
testMask
(
self
,
sp
,
dtype
):
# Create mask
f
=
ift
.
from_random
(
'normal'
,
sp
).
to_global_data
()
mask
=
np
.
zeros_like
(
f
)
mask
[
f
>
0
]
=
1
mask
=
ift
.
Field
.
from_global_data
(
sp
,
mask
)
# Test MaskOperator
op
=
ift
.
MaskOperator
(
mask
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
(
_h_spaces
+
_p_spaces
,
[
np
.
float64
,
np
.
complex128
]))
def
testDiagonal
(
self
,
sp
,
dtype
):
op
=
ift
.
DiagonalOperator
(
ift
.
Field
.
from_random
(
"normal"
,
sp
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment