Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
caf95cad
Commit
caf95cad
authored
Jan 17, 2019
by
Philipp Arras
Browse files
Add default value to ValueInserter
parent
8b210cbd
Changes
2
Hide whitespace changes
Inline
Side-by-side
nifty5/operators/value_inserter.py
View file @
caf95cad
...
...
@@ -15,6 +15,9 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from
functools
import
reduce
from
operator
import
mul
import
numpy
as
np
from
..domain_tuple
import
DomainTuple
...
...
@@ -25,7 +28,7 @@ from .linear_operator import LinearOperator
class
ValueInserter
(
LinearOperator
):
"""Inserts one value into a field which is
zero
otherwise.
"""Inserts one value into a field which is
constant
otherwise.
Parameters
----------
...
...
@@ -33,11 +36,16 @@ class ValueInserter(LinearOperator):
index : iterable of int
The index of the target into which the value of the domain shall be
inserted.
default_value : float
Constant value which is inserted everywhere where the input operator
is not inserted. Default is 0.
"""
def
__init__
(
self
,
target
,
index
):
def
__init__
(
self
,
target
,
index
,
default_value
=
0.
):
self
.
_domain
=
makeDomain
(
UnstructuredDomain
(
1
))
self
.
_target
=
DomainTuple
.
make
(
target
)
# Type and value checks
index
=
tuple
(
index
)
if
not
all
([
isinstance
(
n
,
int
)
and
n
>=
0
and
n
<
self
.
target
.
shape
[
i
]
...
...
@@ -46,17 +54,19 @@ class ValueInserter(LinearOperator):
raise
TypeError
if
not
len
(
index
)
==
len
(
self
.
target
.
shape
):
raise
ValueError
np
.
empty
(
self
.
target
.
shape
)[
index
]
self
.
_index
=
index
self
.
_dv
=
float
(
default_value
)
self
.
_dvsum
=
self
.
_dv
*
(
reduce
(
mul
,
self
.
target
.
shape
)
-
1
)
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
# Check whether index is in bounds
np
.
empty
(
self
.
target
.
shape
)[
self
.
_index
]
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
x
=
x
.
to_global_data
()
if
mode
==
self
.
TIMES
:
res
=
np
.
zeros
(
self
.
target
.
shape
,
dtype
=
x
.
dtype
)
res
=
np
.
full
(
self
.
target
.
shape
,
self
.
_dv
,
dtype
=
x
.
dtype
)
res
[
self
.
_index
]
=
x
else
:
res
=
np
.
full
((
1
,),
x
[
self
.
_index
],
dtype
=
x
.
dtype
)
res
=
np
.
full
((
1
,),
x
[
self
.
_index
]
+
self
.
_dvsum
,
dtype
=
x
.
dtype
)
return
Field
.
from_global_data
(
self
.
_tgt
(
mode
),
res
)
test/test_operators/test_value_inserter.py
View file @
caf95cad
...
...
@@ -17,7 +17,7 @@
import
numpy
as
np
import
pytest
from
numpy.testing
import
assert_
from
numpy.testing
import
assert_
allclose
import
nifty5
as
ift
...
...
@@ -37,5 +37,17 @@ def test_value_inserter(sp, seed):
f
=
ift
.
from_random
(
'normal'
,
ift
.
UnstructuredDomain
((
1
,)))
inp
=
f
.
to_global_data
()[
0
]
ret
=
op
(
f
).
to_global_data
()
assert_
(
ret
[
ind
]
==
inp
)
assert_
(
np
.
sum
(
ret
)
==
inp
)
assert_allclose
(
ret
[
ind
],
inp
)
assert_allclose
(
np
.
sum
(
ret
),
inp
)
def
test_value_inserter_nonzero
():
sp
=
ift
.
RGSpace
(
4
)
ind
=
(
1
,)
default
=
1.24
op
=
ift
.
ValueInserter
(
sp
,
ind
,
default
)
f
=
ift
.
from_random
(
'normal'
,
ift
.
UnstructuredDomain
((
1
,)))
inp
=
f
.
to_global_data
()[
0
]
ret
=
op
(
f
).
to_global_data
()
assert_allclose
(
ret
[
ind
],
inp
)
assert_allclose
(
np
.
sum
(
ret
),
inp
+
3
*
default
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment