Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
f4b1c9d7
Commit
f4b1c9d7
authored
Aug 27, 2018
by
Philipp Arras
Browse files
Refactor DomainTupleFieldInserter
parent
f8d22093
Changes
2
Hide whitespace changes
Inline
Side-by-side
nifty5/operators/domain_tuple_field_inserter.py
View file @
f4b1c9d7
...
...
@@ -27,43 +27,32 @@ from .linear_operator import LinearOperator
class
DomainTupleFieldInserter
(
LinearOperator
):
def
__init__
(
self
,
domain
,
new_space
,
ind
,
infront
=
False
):
def
__init__
(
self
,
domain
,
new_space
,
ind
ex
,
position
):
'''Writes the content of a field into one slice of a DomainTuple.
Parameters
----------
domain : Domain, tuple of Domain or DomainTuple
new_space : Domain, tuple of Domain or DomainTuple
ind : Integer
Index of the same space as new_space
infront : Boolean
If true, the new domain is added in the beginning of the
DomainTuple. Otherwise it is added at the end.
index : Integer
Position at which new_space shall be added to domain.
position : tuple
Slice in new_space at which the field shall be inserted.
'''
# FIXME Add assertions
self
.
_domain
=
DomainTuple
.
make
(
domain
)
if
infront
:
self
.
_target
=
DomainTuple
.
make
([
new_space
]
+
list
(
self
.
domain
))
else
:
self
.
_target
=
DomainTuple
.
make
(
list
(
self
.
domain
)
+
[
new_space
])
self
.
_infront
=
infront
tgt
=
list
(
self
.
domain
)
tgt
.
insert
(
index
,
new_space
)
self
.
_target
=
DomainTuple
.
make
(
tgt
)
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
self
.
_ind
=
ind
self
.
_slc
=
(
slice
(
None
),)
*
index
+
position
+
(
slice
(
None
),)
*
(
len
(
self
.
domain
.
shape
)
-
index
)
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
res
=
np
.
zeros
(
self
.
target
.
shape
,
dtype
=
x
.
dtype
)
if
self
.
_infront
:
res
[
self
.
_ind
]
=
x
.
to_global_data
()
else
:
res
[...,
self
.
_ind
]
=
x
.
to_global_data
()
res
[
self
.
_slc
]
=
x
.
to_global_data
()
return
Field
.
from_global_data
(
self
.
target
,
res
)
else
:
if
self
.
_infront
:
return
Field
.
from_global_data
(
self
.
domain
,
x
.
to_global_data
()[
self
.
_ind
])
else
:
return
Field
.
from_global_data
(
self
.
domain
,
x
.
to_global_data
()[...,
self
.
_ind
])
return
Field
.
from_global_data
(
self
.
domain
,
x
.
to_global_data
()[
self
.
_slc
])
test/test_operators/test_adjoint.py
View file @
f4b1c9d7
...
...
@@ -194,13 +194,12 @@ class Consistency_Tests(unittest.TestCase):
op
=
ift
.
ContractionOperator
(
dom
,
spaces
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
([
True
,
False
]))
def
testDomainTupleFieldInserter
(
self
,
infront
):
def
testDomainTupleFieldInserter
(
self
):
domain
=
ift
.
DomainTuple
.
make
((
ift
.
UnstructuredDomain
(
12
),
ift
.
RGSpace
([
4
,
22
])))
new_space
=
ift
.
UnstructuredDomain
(
7
)
ind
=
5
op
=
ift
.
DomainTupleFieldInserter
(
domain
,
new_space
,
ind
,
infront
)
pos
=
(
5
,)
op
=
ift
.
DomainTupleFieldInserter
(
domain
,
new_space
,
0
,
pos
)
ift
.
extra
.
consistency_check
(
op
)
@
expand
(
product
([
0
,
2
],
[
np
.
float64
,
np
.
complex128
]))
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment