Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
13c8fe93
Commit
13c8fe93
authored
Jun 28, 2018
by
Philipp Arras
Browse files
Implement MaskOperator
parent
4a29e1ec
Changes
2
Hide whitespace changes
Inline
Side-by-side
nifty5/operators/mask_operator.py
View file @
13c8fe93
...
...
@@ -16,33 +16,35 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import
numpy
as
np
from
..domain_tuple
import
DomainTuple
from
..domains.unstructured_domain
import
UnstructuredDomain
from
..field
import
Field
from
..sugar
import
full
from
.linear_operator
import
LinearOperator
class
MaskOperator
(
LinearOperator
):
def
__init__
(
self
,
domain
,
target
,
xy
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
# TODO Takes a field (boolean or 0/1)
# TODO Add MultiFields (output MultiField of unstructured domains)
def
__init__
(
self
,
mask
):
if
not
isinstance
(
mask
,
Field
):
raise
TypeError
assert
len
(
xy
.
shape
)
==
2
assert
xy
.
shape
[
1
]
==
2
self
.
_target
=
UnstructuredDomain
(
xy
.
shape
[
0
]
)
self
.
_domain
=
DomainTuple
.
make
(
mask
.
domain
)
self
.
_mask
=
np
.
logical_not
(
mask
.
to_global_data
())
self
.
_target
=
DomainTuple
.
make
(
UnstructuredDomain
(
self
.
_mask
.
sum
())
)
self
.
_xs
=
xy
.
T
[
0
]
self
.
_ys
=
xy
.
T
[
1
]
def
data_indices
(
self
):
return
np
.
indices
(
self
.
domain
.
shape
).
transpose
((
1
,
2
,
0
))[
self
.
_mask
]
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
res
=
x
.
val
[
self
.
_xs
,
self
.
_ys
]
res
=
x
.
to_global_data
()[
self
.
_mask
]
return
Field
(
self
.
target
,
res
)
res
=
full
(
self
.
domain
,
0.
)
res
[
self
.
_xs
,
self
.
_ys
]
=
x
.
val
x
=
x
.
to_global_data
()
res
=
np
.
empty
(
self
.
domain
.
shape
,
x
.
dtype
)
res
[
self
.
_mask
]
=
x
res
[
~
self
.
_mask
]
=
0
return
Field
(
self
.
domain
,
res
)
@
property
...
...
test/test_operators/test_adjoint.py
View file @
13c8fe93
...
...
@@ -61,6 +61,17 @@ class Consistency_Tests(unittest.TestCase):
op
=
ift
.
HarmonicTransformOperator
(
sp
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
(
_p_spaces
,
[
np
.
float64
,
np
.
complex128
]))
def
testMask
(
self
,
sp
,
dtype
):
# Create mask
f
=
ift
.
from_random
(
'normal'
,
sp
).
val
mask
=
np
.
zeros_like
(
f
)
mask
[
f
>
0
]
=
1
mask
=
ift
.
Field
(
sp
,
mask
)
# Test MaskOperator
op
=
ift
.
MaskOperator
(
mask
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
(
_h_spaces
+
_p_spaces
,
[
np
.
float64
,
np
.
complex128
]))
def
testDiagonal
(
self
,
sp
,
dtype
):
op
=
ift
.
DiagonalOperator
(
ift
.
Field
.
from_random
(
"normal"
,
sp
,
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment