Commit 4f8fa350 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'cosmetics' into 'NIFTy_5'

Functionatliy of Value Inserter

See merge request ift/nifty-dev!187
parents c5f608f1 caf95cad
......@@ -15,6 +15,9 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from functools import reduce
from operator import mul
import numpy as np
from ..domain_tuple import DomainTuple
......@@ -25,7 +28,7 @@ from .linear_operator import LinearOperator
class ValueInserter(LinearOperator):
"""Inserts one value into a field which is zero otherwise.
"""Inserts one value into a field which is constant otherwise.
......@@ -33,25 +36,37 @@ class ValueInserter(LinearOperator):
index : iterable of int
The index of the target into which the value of the domain shall be
default_value : float
Constant value which is inserted everywhere where the input operator
is not inserted. Default is 0.
def __init__(self, target, index):
def __init__(self, target, index, default_value=0.):
self._domain = makeDomain(UnstructuredDomain(1))
self._target = DomainTuple.make(target)
# Type and value checks
index = tuple(index)
if not all([isinstance(n, int) and n>=0 and n<[i] for i, n in enumerate(index)]):
if not all([
isinstance(n, int) and n >= 0 and n <[i]
for i, n in enumerate(index)
raise TypeError
if not len(index) == len(
raise ValueError
self._index = index
self._dv = float(default_value)
self._dvsum = self._dv*(reduce(mul, - 1)
self._capability = self.TIMES | self.ADJOINT_TIMES
# Check whether index is in bounds
def apply(self, x, mode):
self._check_input(x, mode)
x = x.to_global_data()
if mode == self.TIMES:
res = np.zeros(, dtype=x.dtype)
res = np.full(, self._dv, dtype=x.dtype)
res[self._index] = x
res = np.full((1,), x[self._index], dtype=x.dtype)
res = np.full((1,), x[self._index] + self._dvsum, dtype=x.dtype)
return Field.from_global_data(self._tgt(mode), res)
......@@ -17,7 +17,7 @@
import numpy as np
import pytest
from numpy.testing import assert_
from numpy.testing import assert_allclose
import nifty5 as ift
......@@ -37,5 +37,17 @@ def test_value_inserter(sp, seed):
f = ift.from_random('normal', ift.UnstructuredDomain((1,)))
inp = f.to_global_data()[0]
ret = op(f).to_global_data()
assert_(ret[ind] == inp)
assert_(np.sum(ret) == inp)
assert_allclose(ret[ind], inp)
assert_allclose(np.sum(ret), inp)
def test_value_inserter_nonzero():
sp = ift.RGSpace(4)
ind = (1,)
default = 1.24
op = ift.ValueInserter(sp, ind, default)
f = ift.from_random('normal', ift.UnstructuredDomain((1,)))
inp = f.to_global_data()[0]
ret = op(f).to_global_data()
assert_allclose(ret[ind], inp)
assert_allclose(np.sum(ret), inp + 3*default)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment