Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
e8ba6f2e
Commit
e8ba6f2e
authored
Mar 24, 2021
by
Martin Reinecke
Browse files
Merge branch 'fix_slice_operator' into 'NIFTy_7'
Fix slice operator See merge request
!606
parents
f10f5f20
21854798
Pipeline
#96746
passed with stages
in 12 minutes and 8 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/operators/selection_operators.py
View file @
e8ba6f2e
...
...
@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-201
9
Max-Planck-Society
# Copyright(C) 2013-20
2
1 Max-Planck-Society
# Authors: Gordian Edenhofer
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...
...
@@ -37,9 +37,11 @@ class SliceOperator(LinearOperator):
----------
domain : Domain, DomainTuple or tuple of Domain
The operator's input domain.
new_shape : tuple of integers or None
new_shape : tuple of
tuples or
integers
,
or None
The shape of the target domain with None indicating to copy the shape
of the original domain for this axis.
of the original domain for this axis. For example ((10, 5), 100) for a
DomainTuple with two entires, the first having shape (10, 5) and the
second having shape 100
center : bool, optional
Whether to center the slice that is selected in the input field.
preserve_dist: bool, optional
...
...
@@ -47,19 +49,27 @@ class SliceOperator(LinearOperator):
"""
def
__init__
(
self
,
domain
,
new_shape
,
center
=
False
,
preserve_dist
=
True
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
if
len
(
new_shape
)
!=
len
(
self
.
_domain
.
shape
):
if
len
(
new_shape
)
!=
len
(
self
.
_domain
):
ve
=
(
f
"shape (
{
new_shape
}
) is incompatible with the shape of the"
f
" domain (
{
self
.
_domain
.
shape
}
)"
)
raise
ValueError
(
ve
)
for
i
,
shape
in
enumerate
(
new_shape
):
if
len
(
np
.
atleast_1d
(
shape
))
!=
len
(
self
.
_domain
[
i
].
shape
):
ve
=
(
f
"shape of subspace (
{
i
}
) is incompatible with the domain"
)
raise
ValueError
(
ve
)
tgt
=
[]
slc_by_ax
=
[]
for
i
,
d
in
enumerate
(
self
.
_domain
):
if
new_shape
[
i
]
is
None
or
self
.
_domain
.
shape
[
i
]
==
new_shape
[
i
]:
if
new_shape
[
i
]
is
None
or
np
.
all
(
np
.
array
(
self
.
_domain
.
shape
[
i
])
==
np
.
array
(
new_shape
[
i
])
):
tgt
+=
[
d
]
elif
new_shape
[
i
]
<
self
.
_domain
.
shape
[
i
]
:
elif
np
.
all
(
np
.
array
(
new_shape
[
i
]
)
<
=
np
.
array
(
d
.
shape
))
:
dom_kw
=
dict
()
if
isinstance
(
d
,
RGSpace
):
if
preserve_dist
:
...
...
@@ -78,14 +88,15 @@ class SliceOperator(LinearOperator):
raise
ValueError
(
ve
)
if
center
:
slc_start
=
np
.
floor
(
(
self
.
_domain
.
shape
[
i
]
-
n
ew_shape
[
i
])
/
2.
).
astype
(
int
)
slc_end
=
slc_start
+
new_shape
[
i
]
for
j
,
n_pix
in
enumerate
(
np
.
atleast_1d
(
new_shape
[
i
])):
slc_start
=
np
.
floor
((
d
.
shape
[
j
]
-
n
_pix
)
/
2.
).
astype
(
int
)
slc_end
=
slc_start
+
n_pix
slc_by_ax
+=
[
slice
(
slc_start
,
slc_end
)
]
else
:
slc_start
=
0
slc_end
=
new_shape
[
i
]
slc_by_ax
+=
[
slice
(
slc_start
,
slc_end
)]
for
n_pix
in
np
.
atleast_1d
(
new_shape
[
i
]):
slc_start
=
0
slc_end
=
n_pix
slc_by_ax
+=
[
slice
(
slc_start
,
slc_end
)]
self
.
_slc_by_ax
=
tuple
(
slc_by_ax
)
self
.
_target
=
DomainTuple
.
make
(
tgt
)
...
...
@@ -102,8 +113,10 @@ class SliceOperator(LinearOperator):
return
Field
.
from_raw
(
self
.
domain
,
res
)
def
__str__
(
self
):
ss
=
(
f
"
{
self
.
__class__
.
__name__
}
"
f
"(
{
self
.
domain
.
shape
}
->
{
self
.
target
.
shape
}
)"
)
ss
=
(
f
"
{
self
.
__class__
.
__name__
}
"
f
"(
{
self
.
domain
.
shape
}
->
{
self
.
target
.
shape
}
)"
)
return
ss
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment