Commit 03f31f69 authored by Philipp Arras's avatar Philipp Arras
Browse files

Revert "Add default value to ValueInserter"

This reverts commit caf95cad.
parent edf4e610
...@@ -15,9 +15,6 @@ ...@@ -15,9 +15,6 @@
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from functools import reduce
from operator import mul
import numpy as np import numpy as np
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
...@@ -28,8 +25,7 @@ from .linear_operator import LinearOperator ...@@ -28,8 +25,7 @@ from .linear_operator import LinearOperator
class ValueInserter(LinearOperator): class ValueInserter(LinearOperator):
# FIXME THIS IS NOT A LINEAR OPERATOR """Inserts one value into a field which is zero otherwise.
"""Inserts one value into a field which is constant otherwise.
Parameters Parameters
---------- ----------
...@@ -37,16 +33,11 @@ class ValueInserter(LinearOperator): ...@@ -37,16 +33,11 @@ class ValueInserter(LinearOperator):
index : iterable of int index : iterable of int
The index of the target into which the value of the domain shall be The index of the target into which the value of the domain shall be
inserted. inserted.
default_value : float
Constant value which is inserted everywhere where the input operator
is not inserted. Default is 0.
""" """
def __init__(self, target, index, default_value=0.): def __init__(self, target, index):
self._domain = makeDomain(UnstructuredDomain(1)) self._domain = makeDomain(UnstructuredDomain(1))
self._target = DomainTuple.make(target) self._target = DomainTuple.make(target)
# Type and value checks
index = tuple(index) index = tuple(index)
if not all([ if not all([
isinstance(n, int) and n >= 0 and n < self.target.shape[i] isinstance(n, int) and n >= 0 and n < self.target.shape[i]
...@@ -55,19 +46,17 @@ class ValueInserter(LinearOperator): ...@@ -55,19 +46,17 @@ class ValueInserter(LinearOperator):
raise TypeError raise TypeError
if not len(index) == len(self.target.shape): if not len(index) == len(self.target.shape):
raise ValueError raise ValueError
np.empty(self.target.shape)[index]
self._index = index self._index = index
self._dv = float(default_value)
self._dvsum = self._dv*(reduce(mul, self.target.shape) - 1)
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
# Check whether index is in bounds
np.empty(self.target.shape)[self._index]
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
x = x.to_global_data() x = x.to_global_data()
if mode == self.TIMES: if mode == self.TIMES:
res = np.full(self.target.shape, self._dv, dtype=x.dtype) res = np.zeros(self.target.shape, dtype=x.dtype)
res[self._index] = x res[self._index] = x
else: else:
res = np.full((1,), x[self._index] + self._dvsum, dtype=x.dtype) res = np.full((1,), x[self._index], dtype=x.dtype)
return Field.from_global_data(self._tgt(mode), res) return Field.from_global_data(self._tgt(mode), res)
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import numpy as np import numpy as np
import pytest import pytest
from numpy.testing import assert_allclose from numpy.testing import assert_
import nifty5 as ift import nifty5 as ift
...@@ -37,17 +37,5 @@ def test_value_inserter(sp, seed): ...@@ -37,17 +37,5 @@ def test_value_inserter(sp, seed):
f = ift.from_random('normal', ift.UnstructuredDomain((1,))) f = ift.from_random('normal', ift.UnstructuredDomain((1,)))
inp = f.to_global_data()[0] inp = f.to_global_data()[0]
ret = op(f).to_global_data() ret = op(f).to_global_data()
assert_allclose(ret[ind], inp) assert_(ret[ind] == inp)
assert_allclose(np.sum(ret), inp) assert_(np.sum(ret) == inp)
def test_value_inserter_nonzero():
sp = ift.RGSpace(4)
ind = (1,)
default = 1.24
op = ift.ValueInserter(sp, ind, default)
f = ift.from_random('normal', ift.UnstructuredDomain((1,)))
inp = f.to_global_data()[0]
ret = op(f).to_global_data()
assert_allclose(ret[ind], inp)
assert_allclose(np.sum(ret), inp + 3*default)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment