Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
38bac55e
Commit
38bac55e
authored
May 18, 2020
by
Martin Reinecke
Browse files
Merge branch 'nifty6_select' into 'NIFTy_6'
Introduce a slicing and a splitting operator See merge request
!461
parents
bd57f855
1e4e8d70
Pipeline
#75094
passed with stages
in 35 minutes and 6 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/__init__.py
View file @
38bac55e
...
...
@@ -38,6 +38,7 @@ from .operators.regridding_operator import RegriddingOperator
from
.operators.sampling_enabler
import
SamplingEnabler
,
SamplingDtypeSetter
from
.operators.sandwich_operator
import
SandwichOperator
from
.operators.scaling_operator
import
ScalingOperator
from
.operators.selection_operators
import
SliceOperator
,
SplitOperator
from
.operators.block_diagonal_operator
import
BlockDiagonalOperator
from
.operators.outer_product_operator
import
OuterProduct
from
.operators.simple_linear_operators
import
(
...
...
nifty6/operators/selection_operators.py
0 → 100644
View file @
38bac55e
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Authors: Gordian Edenhofer
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import
numpy
as
np
from
..domain_tuple
import
DomainTuple
from
..domains.rg_space
import
RGSpace
from
..domains.unstructured_domain
import
UnstructuredDomain
from
..field
import
Field
from
..multi_domain
import
MultiDomain
from
..multi_field
import
MultiField
from
.linear_operator
import
LinearOperator
class
SliceOperator
(
LinearOperator
):
"""Geometry preserving mask operator
Takes a field, slices it into the desired shape and returns the values of
the field in the sliced domain all while preserving the original distances.
Parameters
----------
domain : Domain, DomainTuple or tuple of Domain
The operator's input domain.
tgt_shape : tuple of integers or None
The shape of the target domain with None indicating to copy the shape
of the original domain for this axis.
center : bool, optional
Whether to center the slice that is selected in the input field.
preserve_dist: bool, optional
Whether to preserve the distance of the input field.
"""
def
__init__
(
self
,
domain
,
tgt_shape
,
center
=
False
,
preserve_dist
=
True
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
if
len
(
tgt_shape
)
!=
len
(
self
.
_domain
.
shape
):
ve
=
(
f
"shape (
{
tgt_shape
}
) is incompatible with the shape of the"
f
" domain (
{
self
.
_domain
.
shape
}
)"
)
raise
ValueError
(
ve
)
tgt
=
[]
slc_by_ax
=
[]
for
i
,
d
in
enumerate
(
self
.
_domain
):
if
tgt_shape
[
i
]
is
None
or
self
.
_domain
.
shape
[
i
]
==
tgt_shape
[
i
]:
tgt
+=
[
d
]
elif
tgt_shape
[
i
]
<
self
.
_domain
.
shape
[
i
]:
dom_kw
=
dict
()
if
isinstance
(
d
,
RGSpace
):
if
preserve_dist
:
dom_kw
[
"distances"
]
=
d
.
distances
dom_kw
[
"harmonic"
]
=
d
.
harmonic
elif
not
isinstance
(
d
,
UnstructuredDomain
):
# Some domains like HPSpace or LMSPace can not be sliced
ve
=
f
"
{
d
.
__class__
.
__name__
}
can not be sliced"
raise
ValueError
(
ve
)
tgt
+=
[
d
.
__class__
(
tgt_shape
[
i
],
**
dom_kw
)]
else
:
ve
=
(
f
"domain axes (
{
d
}
) is smaller than the target shape"
f
"
{
tgt_shape
[
i
]
}
"
)
raise
ValueError
(
ve
)
if
center
:
slc_start
=
np
.
floor
(
(
self
.
_domain
.
shape
[
i
]
-
tgt_shape
[
i
])
/
2.
).
astype
(
int
)
slc_end
=
slc_start
+
tgt_shape
[
i
]
else
:
slc_start
=
0
slc_end
=
tgt_shape
[
i
]
slc_by_ax
+=
[
slice
(
slc_start
,
slc_end
)]
self
.
_slc_by_ax
=
tuple
(
slc_by_ax
)
self
.
_target
=
DomainTuple
.
make
(
tgt
)
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
x
=
x
.
val
if
mode
==
self
.
TIMES
:
res
=
x
[
self
.
_slc_by_ax
]
return
Field
.
from_raw
(
self
.
target
,
res
)
res
=
np
.
zeros
(
self
.
domain
.
shape
,
x
.
dtype
)
res
[
self
.
_slc_by_ax
]
=
x
return
Field
.
from_raw
(
self
.
domain
,
res
)
def
__str__
(
self
):
ss
=
(
f
"
{
self
.
__class__
.
__name__
}
"
f
"(
{
self
.
domain
.
shape
}
->
{
self
.
target
.
shape
}
)"
)
return
ss
class
SplitOperator
(
LinearOperator
):
"""Split a single field into a multi-field
Takes a field, selects the desired entries for each multi-field key and
puts the result into a multi-field. Along sliced axis, the domain will
be replaced by an UnstructuredDomain as no distance measures are preserved.
Note, slices may intersect, i.e. slices may reference the same input
multiple times if the `intersecting_slices` option is set. However, a
single field in the output may not contain the same part of the input more
than once.
Parameters
----------
domain : Domain, DomainTuple or tuple of Domain
The operator's input domain.
slices_by_key : dict{key: tuple of integers or None}
The key-value pairs of which the values indicate the parts to be
selected. The result will be a multi-field with the given keys as
entries and the selected slices of the domain as values. `None`
indicates to select the whole input along this axis.
intersecting_slices : bool, optional
Tells the operator whether slices may contain intersections. If true,
the adjoint is constructed a little less efficiently. Set this
parameter to `False` to gain a little more efficiency.
"""
def
__init__
(
self
,
domain
,
slices_by_key
,
intersecting_slices
=
True
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_intersec_slc
=
intersecting_slices
tgt
=
dict
()
self
.
_k_slc
=
dict
()
for
k
,
slc
in
slices_by_key
.
items
():
if
len
(
slc
)
>
len
(
self
.
_domain
):
ve
=
f
"slice at key
{
k
!r}
has more dimensions than the input"
raise
ValueError
(
ve
)
k_tgt
=
[]
k_slc_by_ax
=
[]
for
i
,
d
in
enumerate
(
self
.
_domain
):
if
i
>=
len
(
slc
)
or
slc
[
i
]
is
None
or
(
isinstance
(
slc
[
i
],
slice
)
and
slc
[
i
]
==
slice
(
None
)
):
k_tgt
+=
[
d
]
k_slc_by_ax
+=
[
slice
(
None
)]
elif
isinstance
(
slc
[
i
],
slice
):
start
=
slc
[
i
].
start
if
slc
[
i
].
start
is
not
None
else
0
stop
=
slc
[
i
].
stop
if
slc
[
i
].
stop
is
not
None
else
d
.
size
step
=
slc
[
i
].
step
if
slc
[
i
].
step
is
not
None
else
1
frac
=
np
.
floor
((
stop
-
start
)
/
np
.
abs
(
step
))
k_tgt
+=
[
UnstructuredDomain
(
frac
.
astype
(
int
))]
k_slc_by_ax
+=
[
slc
[
i
]]
elif
isinstance
(
slc
[
i
],
np
.
ndarray
)
and
slc
[
i
].
dtype
is
np
.
dtype
(
bool
):
if
slc
[
i
].
size
!=
d
.
size
:
ve
=
(
"shape mismatch between desired slice {slc[i]}"
"and the shape of the domain {d.size}"
)
raise
ValueError
(
ve
)
k_tgt
+=
[
UnstructuredDomain
(
slc
[
i
].
sum
())]
k_slc_by_ax
+=
[
slc
[
i
]]
elif
isinstance
(
slc
[
i
],
(
tuple
,
list
,
np
.
ndarray
)):
k_tgt
+=
[
UnstructuredDomain
(
len
(
slc
[
i
]))]
k_slc_by_ax
+=
[
slc
[
i
]]
elif
isinstance
(
slc
[
i
],
int
):
k_slc_by_ax
+=
[
slc
[
i
]]
else
:
ve
=
f
"invalid type for specifying a slice; got
{
slc
[
i
]
}
"
raise
ValueError
(
ve
)
tgt
[
k
]
=
DomainTuple
.
make
(
k_tgt
)
self
.
_k_slc
[
k
]
=
tuple
(
k_slc_by_ax
)
self
.
_target
=
MultiDomain
.
make
(
tgt
)
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
x
=
x
.
val
if
mode
==
self
.
TIMES
:
res
=
dict
()
for
k
,
slc
in
self
.
_k_slc
.
items
():
res
[
k
]
=
x
[
slc
]
return
MultiField
.
from_raw
(
self
.
target
,
res
)
# Note, not-selected parts must be zero. Hence, using the quicker
# `np.empty` method is unfortunately not possible
res
=
np
.
zeros
(
self
.
domain
.
shape
,
tuple
(
x
.
values
())[
0
].
dtype
)
if
self
.
_intersec_slc
:
for
k
,
slc
in
self
.
_k_slc
.
items
():
# Mind the `+` here for coping with intersections
res
[
slc
]
+=
x
[
k
]
return
Field
.
from_raw
(
self
.
domain
,
res
)
for
k
,
slc
in
self
.
_k_slc
.
items
():
res
[
slc
]
=
x
[
k
]
return
Field
.
from_raw
(
self
.
domain
,
res
)
def
__str__
(
self
):
return
f
"
{
self
.
__class__
.
__name__
}
{
self
.
_target
.
keys
()
!r}
<-"
test/test_operators/test_selection_operators.py
0 → 100644
View file @
38bac55e
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Authors: Gordian Edenhofer
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import
pytest
from
numpy.testing
import
assert_allclose
,
assert_array_equal
from
nifty6.extra
import
consistency_check
import
numpy
as
np
import
nifty6
as
ift
from
..common
import
list2fixture
,
setup_function
,
teardown_function
pmp
=
pytest
.
mark
.
parametrize
# The test cases do not work on a multi-dimensional RGSpace yet
spaces
=
(
ift
.
UnstructuredDomain
(
4
),
ift
.
LMSpace
(
5
),
ift
.
GLSpace
(
4
),
)
space1
=
list2fixture
(
spaces
)
space2
=
list2fixture
(
spaces
)
dtype
=
list2fixture
([
np
.
float64
,
np
.
complex128
])
def
test_split_operator_first_axes_without_intersections
(
space1
,
space2
,
n_splits
=
3
):
rng
=
ift
.
random
.
current_rng
()
dom
=
ift
.
DomainTuple
.
make
((
space1
,
space2
))
orig_idx
=
np
.
arange
(
space1
.
shape
[
0
])
rng
.
shuffle
(
orig_idx
)
split_idx
=
np
.
array_split
(
orig_idx
,
n_splits
)
split
=
ift
.
SplitOperator
(
dom
,
{
f
"
{
i
:
06
d
}
"
:
(
si
,
)
for
i
,
si
in
enumerate
(
split_idx
)}
)
assert
consistency_check
(
split
)
is
None
r
=
ift
.
from_random
(
"normal"
,
dom
)
split_r
=
split
(
r
)
# This relies on the keys of the target domain either being in the order of
# insertion or being alphabetically sorted
for
idx
,
v
in
zip
(
split_idx
,
split_r
.
val
.
values
()):
assert_array_equal
(
r
.
val
[
idx
],
v
)
# Here, the adjoint must be the inverse as the field is split fully among
# the generated indices and without intersections.
assert_array_equal
(
split
.
adjoint
(
split_r
).
val
,
r
.
val
)
def
test_split_operator_first_axes_with_intersections
(
space1
,
space2
,
n_splits
=
3
):
rng
=
ift
.
random
.
current_rng
()
dom
=
ift
.
DomainTuple
.
make
((
space1
,
space2
))
orig_idx
=
np
.
arange
(
space1
.
shape
[
0
])
split_idx
=
[
rng
.
choice
(
orig_idx
,
rng
.
integers
(
1
,
space1
.
shape
[
0
]),
replace
=
False
)
for
_
in
range
(
n_splits
)
]
split
=
ift
.
SplitOperator
(
dom
,
{
f
"
{
i
:
06
d
}
"
:
(
si
,
)
for
i
,
si
in
enumerate
(
split_idx
)}
)
print
(
split_idx
)
assert
consistency_check
(
split
)
is
None
r
=
ift
.
from_random
(
"normal"
,
dom
)
split_r
=
split
(
r
)
# This relies on the keys of the target domain either being in the order of
# insertion or being alphabetically sorted
for
idx
,
v
in
zip
(
split_idx
,
split_r
.
val
.
values
()):
assert_array_equal
(
r
.
val
[
idx
],
v
)
r_diy
=
np
.
copy
(
r
.
val
)
unique_freq
=
np
.
unique
(
np
.
concatenate
(
split_idx
),
return_counts
=
True
)
# Null values that were not selected
r_diy
[
list
(
set
(
unique_freq
[
0
])
^
set
(
range
(
space1
.
shape
[
0
])))]
=
0.
for
idx
,
freq
in
zip
(
*
unique_freq
):
r_diy
[
idx
]
*=
freq
assert_allclose
(
split
.
adjoint
(
split_r
).
val
,
r_diy
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment