Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
4f8fa350
Commit
4f8fa350
authored
Jan 17, 2019
by
Martin Reinecke
Browse files
Merge branch 'cosmetics' into 'NIFTy_5'
Functionatliy of Value Inserter See merge request ift/nifty-dev!187
parents
c5f608f1
caf95cad
Changes
2
Hide whitespace changes
Inline
Side-by-side
nifty5/operators/value_inserter.py
View file @
4f8fa350
...
...
@@ -15,6 +15,9 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from
functools
import
reduce
from
operator
import
mul
import
numpy
as
np
from
..domain_tuple
import
DomainTuple
...
...
@@ -25,7 +28,7 @@ from .linear_operator import LinearOperator
class
ValueInserter
(
LinearOperator
):
"""Inserts one value into a field which is
zero
otherwise.
"""Inserts one value into a field which is
constant
otherwise.
Parameters
----------
...
...
@@ -33,25 +36,37 @@ class ValueInserter(LinearOperator):
index : iterable of int
The index of the target into which the value of the domain shall be
inserted.
default_value : float
Constant value which is inserted everywhere where the input operator
is not inserted. Default is 0.
"""
def
__init__
(
self
,
target
,
index
):
def
__init__
(
self
,
target
,
index
,
default_value
=
0.
):
self
.
_domain
=
makeDomain
(
UnstructuredDomain
(
1
))
self
.
_target
=
DomainTuple
.
make
(
target
)
# Type and value checks
index
=
tuple
(
index
)
if
not
all
([
isinstance
(
n
,
int
)
and
n
>=
0
and
n
<
self
.
target
.
shape
[
i
]
for
i
,
n
in
enumerate
(
index
)]):
if
not
all
([
isinstance
(
n
,
int
)
and
n
>=
0
and
n
<
self
.
target
.
shape
[
i
]
for
i
,
n
in
enumerate
(
index
)
]):
raise
TypeError
if
not
len
(
index
)
==
len
(
self
.
target
.
shape
):
raise
ValueError
np
.
empty
(
self
.
target
.
shape
)[
index
]
self
.
_index
=
index
self
.
_dv
=
float
(
default_value
)
self
.
_dvsum
=
self
.
_dv
*
(
reduce
(
mul
,
self
.
target
.
shape
)
-
1
)
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
# Check whether index is in bounds
np
.
empty
(
self
.
target
.
shape
)[
self
.
_index
]
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
x
=
x
.
to_global_data
()
if
mode
==
self
.
TIMES
:
res
=
np
.
zeros
(
self
.
target
.
shape
,
dtype
=
x
.
dtype
)
res
=
np
.
full
(
self
.
target
.
shape
,
self
.
_dv
,
dtype
=
x
.
dtype
)
res
[
self
.
_index
]
=
x
else
:
res
=
np
.
full
((
1
,),
x
[
self
.
_index
],
dtype
=
x
.
dtype
)
res
=
np
.
full
((
1
,),
x
[
self
.
_index
]
+
self
.
_dvsum
,
dtype
=
x
.
dtype
)
return
Field
.
from_global_data
(
self
.
_tgt
(
mode
),
res
)
test/test_operators/test_value_inserter.py
View file @
4f8fa350
...
...
@@ -17,7 +17,7 @@
import
numpy
as
np
import
pytest
from
numpy.testing
import
assert_
from
numpy.testing
import
assert_
allclose
import
nifty5
as
ift
...
...
@@ -37,5 +37,17 @@ def test_value_inserter(sp, seed):
f
=
ift
.
from_random
(
'normal'
,
ift
.
UnstructuredDomain
((
1
,)))
inp
=
f
.
to_global_data
()[
0
]
ret
=
op
(
f
).
to_global_data
()
assert_
(
ret
[
ind
]
==
inp
)
assert_
(
np
.
sum
(
ret
)
==
inp
)
assert_allclose
(
ret
[
ind
],
inp
)
assert_allclose
(
np
.
sum
(
ret
),
inp
)
def
test_value_inserter_nonzero
():
sp
=
ift
.
RGSpace
(
4
)
ind
=
(
1
,)
default
=
1.24
op
=
ift
.
ValueInserter
(
sp
,
ind
,
default
)
f
=
ift
.
from_random
(
'normal'
,
ift
.
UnstructuredDomain
((
1
,)))
inp
=
f
.
to_global_data
()[
0
]
ret
=
op
(
f
).
to_global_data
()
assert_allclose
(
ret
[
ind
],
inp
)
assert_allclose
(
np
.
sum
(
ret
),
inp
+
3
*
default
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment