diff --git a/nifty5/operators/value_inserter.py b/nifty5/operators/value_inserter.py index db568ecf9883b20e98bb1747bb0c2967789234d5..ce1142da23d2ca031ea1f7b658ce2646009f6011 100644 --- a/nifty5/operators/value_inserter.py +++ b/nifty5/operators/value_inserter.py @@ -15,6 +15,9 @@ # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. +from functools import reduce +from operator import mul + import numpy as np from ..domain_tuple import DomainTuple @@ -25,7 +28,7 @@ from .linear_operator import LinearOperator class ValueInserter(LinearOperator): - """Inserts one value into a field which is zero otherwise. + """Inserts one value into a field which is constant otherwise. Parameters ---------- @@ -33,11 +36,16 @@ class ValueInserter(LinearOperator): index : iterable of int The index of the target into which the value of the domain shall be inserted. + default_value : float + Constant value which is inserted everywhere where the input operator + is not inserted. Default is 0. """ - def __init__(self, target, index): + def __init__(self, target, index, default_value=0.): self._domain = makeDomain(UnstructuredDomain(1)) self._target = DomainTuple.make(target) + + # Type and value checks index = tuple(index) if not all([ isinstance(n, int) and n >= 0 and n < self.target.shape[i] @@ -46,17 +54,19 @@ class ValueInserter(LinearOperator): raise TypeError if not len(index) == len(self.target.shape): raise ValueError + np.empty(self.target.shape)[index] + self._index = index + self._dv = float(default_value) + self._dvsum = self._dv*(reduce(mul, self.target.shape) - 1) self._capability = self.TIMES | self.ADJOINT_TIMES - # Check whether index is in bounds - np.empty(self.target.shape)[self._index] def apply(self, x, mode): self._check_input(x, mode) x = x.to_global_data() if mode == self.TIMES: - res = np.zeros(self.target.shape, dtype=x.dtype) + res = np.full(self.target.shape, self._dv, dtype=x.dtype) res[self._index] = x else: - res = np.full((1,), x[self._index], dtype=x.dtype) + res = np.full((1,), x[self._index] + self._dvsum, dtype=x.dtype) return Field.from_global_data(self._tgt(mode), res) diff --git a/test/test_operators/test_value_inserter.py b/test/test_operators/test_value_inserter.py index 2f2e34253829129aac4b94171d8b024468961021..c84efb159c3cac32df96250e7743199e48f71de5 100644 --- a/test/test_operators/test_value_inserter.py +++ b/test/test_operators/test_value_inserter.py @@ -17,7 +17,7 @@ import numpy as np import pytest -from numpy.testing import assert_ +from numpy.testing import assert_allclose import nifty5 as ift @@ -37,5 +37,17 @@ def test_value_inserter(sp, seed): f = ift.from_random('normal', ift.UnstructuredDomain((1,))) inp = f.to_global_data()[0] ret = op(f).to_global_data() - assert_(ret[ind] == inp) - assert_(np.sum(ret) == inp) + assert_allclose(ret[ind], inp) + assert_allclose(np.sum(ret), inp) + + +def test_value_inserter_nonzero(): + sp = ift.RGSpace(4) + ind = (1,) + default = 1.24 + op = ift.ValueInserter(sp, ind, default) + f = ift.from_random('normal', ift.UnstructuredDomain((1,))) + inp = f.to_global_data()[0] + ret = op(f).to_global_data() + assert_allclose(ret[ind], inp) + assert_allclose(np.sum(ret), inp + 3*default)