diff --git a/nifty5/operators/value_inserter.py b/nifty5/operators/value_inserter.py index 1b6fc18061a38075d4fca61f2926f759de289616..db568ecf9883b20e98bb1747bb0c2967789234d5 100644 --- a/nifty5/operators/value_inserter.py +++ b/nifty5/operators/value_inserter.py @@ -15,9 +15,6 @@ # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. -from functools import reduce -from operator import mul - import numpy as np from ..domain_tuple import DomainTuple @@ -28,8 +25,7 @@ from .linear_operator import LinearOperator class ValueInserter(LinearOperator): - # FIXME THIS IS NOT A LINEAR OPERATOR - """Inserts one value into a field which is constant otherwise. + """Inserts one value into a field which is zero otherwise. Parameters ---------- @@ -37,16 +33,11 @@ class ValueInserter(LinearOperator): index : iterable of int The index of the target into which the value of the domain shall be inserted. - default_value : float - Constant value which is inserted everywhere where the input operator - is not inserted. Default is 0. """ - def __init__(self, target, index, default_value=0.): + def __init__(self, target, index): self._domain = makeDomain(UnstructuredDomain(1)) self._target = DomainTuple.make(target) - - # Type and value checks index = tuple(index) if not all([ isinstance(n, int) and n >= 0 and n < self.target.shape[i] @@ -55,19 +46,17 @@ class ValueInserter(LinearOperator): raise TypeError if not len(index) == len(self.target.shape): raise ValueError - np.empty(self.target.shape)[index] - self._index = index - self._dv = float(default_value) - self._dvsum = self._dv*(reduce(mul, self.target.shape) - 1) self._capability = self.TIMES | self.ADJOINT_TIMES + # Check whether index is in bounds + np.empty(self.target.shape)[self._index] def apply(self, x, mode): self._check_input(x, mode) x = x.to_global_data() if mode == self.TIMES: - res = np.full(self.target.shape, self._dv, dtype=x.dtype) + res = np.zeros(self.target.shape, dtype=x.dtype) res[self._index] = x else: - res = np.full((1,), x[self._index] + self._dvsum, dtype=x.dtype) + res = np.full((1,), x[self._index], dtype=x.dtype) return Field.from_global_data(self._tgt(mode), res) diff --git a/test/test_operators/test_value_inserter.py b/test/test_operators/test_value_inserter.py index c84efb159c3cac32df96250e7743199e48f71de5..2f2e34253829129aac4b94171d8b024468961021 100644 --- a/test/test_operators/test_value_inserter.py +++ b/test/test_operators/test_value_inserter.py @@ -17,7 +17,7 @@ import numpy as np import pytest -from numpy.testing import assert_allclose +from numpy.testing import assert_ import nifty5 as ift @@ -37,17 +37,5 @@ def test_value_inserter(sp, seed): f = ift.from_random('normal', ift.UnstructuredDomain((1,))) inp = f.to_global_data()[0] ret = op(f).to_global_data() - assert_allclose(ret[ind], inp) - assert_allclose(np.sum(ret), inp) - - -def test_value_inserter_nonzero(): - sp = ift.RGSpace(4) - ind = (1,) - default = 1.24 - op = ift.ValueInserter(sp, ind, default) - f = ift.from_random('normal', ift.UnstructuredDomain((1,))) - inp = f.to_global_data()[0] - ret = op(f).to_global_data() - assert_allclose(ret[ind], inp) - assert_allclose(np.sum(ret), inp + 3*default) + assert_(ret[ind] == inp) + assert_(np.sum(ret) == inp)