Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
0327ffdc
Commit
0327ffdc
authored
Jan 21, 2019
by
Philipp Arras
Browse files
Adapt interface of DomainTupleFieldInserter
parent
4f79d69c
Changes
2
Hide whitespace changes
Inline
Side-by-side
nifty5/operators/domain_tuple_field_inserter.py
View file @
0327ffdc
...
...
@@ -23,31 +23,38 @@ from .linear_operator import LinearOperator
class
DomainTupleFieldInserter
(
LinearOperator
):
"""Writes the content of a :class:`Field` into one slice of a :class:`DomainTuple`.
"""Writes the content of a :class:`Field` into one slice of a
:class:`DomainTuple`.
Parameters
----------
domain : Domain, tuple of Domain or DomainTuple
new_space : Domain, tuple of Domain or DomainTuple
index : Integer
Index at which new_space shall be added to domain.
position : tuple
Slice in new_space in which the input field shall be written into.
target : Domain, tuple of Domain or DomainTuple
space : int
The index of the sub-domain which is inserted.
index : tuple
Slice in new sub-domain in which the input field shall be written into.
"""
def
__init__
(
self
,
domain
,
new_space
,
index
,
position
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
tgt
=
list
(
self
.
domain
)
tgt
.
insert
(
index
,
new_space
)
self
.
_target
=
DomainTuple
.
make
(
tgt
)
def
__init__
(
self
,
target
,
space
,
pos
):
if
not
space
<=
len
(
target
)
or
space
<
0
:
raise
ValueError
self
.
_target
=
DomainTuple
.
make
(
target
)
dom
=
list
(
self
.
target
)
dom
.
pop
(
space
)
self
.
_domain
=
DomainTuple
.
make
(
dom
)
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
fst_dims
=
sum
(
len
(
dd
.
shape
)
for
dd
in
self
.
domain
[:
index
])
new_space
=
target
[
space
]
nshp
=
new_space
.
shape
if
len
(
position
)
!=
len
(
nshp
):
fst_dims
=
sum
(
len
(
dd
.
shape
)
for
dd
in
self
.
target
[:
space
])
if
len
(
pos
)
!=
len
(
nshp
):
raise
ValueError
(
"shape mismatch between new_space and position"
)
for
s
,
p
in
zip
(
nshp
,
pos
ition
):
for
s
,
p
in
zip
(
nshp
,
pos
):
if
p
<
0
or
p
>=
s
:
raise
ValueError
(
"bad position value"
)
self
.
_slc
=
(
slice
(
None
),)
*
fst_dims
+
position
self
.
_slc
=
(
slice
(
None
),)
*
fst_dims
+
pos
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
...
...
test/test_operators/test_adjoint.py
View file @
0327ffdc
...
...
@@ -189,11 +189,10 @@ def testContractionOperator(spaces, wgt, dtype):
def
testDomainTupleFieldInserter
():
domain
=
ift
.
DomainTuple
.
make
((
ift
.
UnstructuredDomain
(
12
),
target
=
ift
.
DomainTuple
.
make
((
ift
.
UnstructuredDomain
([
3
,
2
]),
ift
.
UnstructuredDomain
(
7
),
ift
.
RGSpace
([
4
,
22
])))
new_space
=
ift
.
UnstructuredDomain
(
7
)
pos
=
(
5
,)
op
=
ift
.
DomainTupleFieldInserter
(
domain
,
new_space
,
0
,
pos
)
op
=
ift
.
DomainTupleFieldInserter
(
target
,
1
,
(
5
,))
ift
.
extra
.
consistency_check
(
op
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment