From 04a93f9e7d01a7acca6cf309f7eb407f58fee062 Mon Sep 17 00:00:00 2001
From: Philipp Arras <parras@mpa-garching.mpg.de>
Date: Tue, 15 Jan 2019 18:28:57 +0100
Subject: [PATCH] Add ValueInserter

---
 nifty5/__init__.py                         |  1 +
 nifty5/operators/value_inserter.py         | 57 ++++++++++++++++++++++
 test/test_operators/test_adjoint.py        |  9 ++++
 test/test_operators/test_value_inserter.py | 41 ++++++++++++++++
 4 files changed, 108 insertions(+)
 create mode 100644 nifty5/operators/value_inserter.py
 create mode 100644 test/test_operators/test_value_inserter.py

diff --git a/nifty5/__init__.py b/nifty5/__init__.py
index 5711a9111..fa5ff526c 100644
--- a/nifty5/__init__.py
+++ b/nifty5/__init__.py
@@ -46,6 +46,7 @@ from .operators.outer_product_operator import OuterProduct
 from .operators.simple_linear_operators import (
     VdotOperator, ConjugationOperator, Realizer,
     FieldAdapter, ducktape, GeometryRemover, NullOperator)
+from .operators.value_inserter import ValueInserter
 from .operators.energy_operators import (
     EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
     BernoulliEnergy, Hamiltonian, SampledKullbachLeiblerDivergence)
diff --git a/nifty5/operators/value_inserter.py b/nifty5/operators/value_inserter.py
new file mode 100644
index 000000000..f0b0f3f47
--- /dev/null
+++ b/nifty5/operators/value_inserter.py
@@ -0,0 +1,57 @@
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+# Copyright(C) 2013-2019 Max-Planck-Society
+#
+# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
+
+import numpy as np
+
+from ..domain_tuple import DomainTuple
+from ..domains.unstructured_domain import UnstructuredDomain
+from ..field import Field
+from ..sugar import makeDomain
+from .linear_operator import LinearOperator
+
+
+class ValueInserter(LinearOperator):
+    """Inserts one value into a field which is zero otherwise.
+
+    Parameters
+    ----------
+    target : Domain, tuple of Domain or DomainTuple
+    index : iterable of int
+        The index of the target into which the value of the domain shall be
+        inserted.
+    """
+
+    def __init__(self, target, index):
+        self._domain = makeDomain(UnstructuredDomain(1))
+        self._target = DomainTuple.make(target)
+        index = tuple(index)
+        if not all([isinstance(n, int) and n>=0 and n<self.target.shape[i] for i, n in enumerate(index)]):
+            raise TypeError
+        self._index = index
+        self._capability = self.TIMES | self.ADJOINT_TIMES
+        # Check whether index is in bounds
+        np.empty(self.target.shape)[self._index]
+
+    def apply(self, x, mode):
+        self._check_input(x, mode)
+        x = x.to_global_data()
+        if mode == self.TIMES:
+            res = np.zeros(self.target.shape, dtype=x.dtype)
+            res[self._index] = x
+        else:
+            res = np.full((1,), x[self._index], dtype=x.dtype)
+        return Field.from_global_data(self._tgt(mode), res)
diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py
index aae4132df..6ba8e2b1d 100644
--- a/test/test_operators/test_adjoint.py
+++ b/test/test_operators/test_adjoint.py
@@ -15,6 +15,8 @@
 #
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
 
+from random import randint
+
 import numpy as np
 import pytest
 
@@ -264,3 +266,10 @@ def testOuter(fdomain, domain):
     f = ift.from_random('normal', fdomain)
     op = ift.OuterProduct(f, domain)
     ift.extra.consistency_check(op)
+
+
+@pmp('sp', _h_spaces + _p_spaces + _pow_spaces)
+def testValueInserter(sp):
+    ind = tuple([randint(0, ss-1) for ss in sp.shape])
+    op = ift.ValueInserter(sp, ind)
+    ift.extra.consistency_check(op)
diff --git a/test/test_operators/test_value_inserter.py b/test/test_operators/test_value_inserter.py
new file mode 100644
index 000000000..09e1d05ee
--- /dev/null
+++ b/test/test_operators/test_value_inserter.py
@@ -0,0 +1,41 @@
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+# Copyright(C) 2013-2019 Max-Planck-Society
+#
+# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
+
+from random import randint
+
+import numpy as np
+import pytest
+from numpy.testing import assert_
+
+import nifty5 as ift
+
+
+@pytest.mark.parametrize('sp', [
+    ift.RGSpace(4),
+    ift.PowerSpace(ift.RGSpace((4, 4), harmonic=True)),
+    ift.LMSpace(5),
+    ift.HPSpace(4),
+    ift.GLSpace(4)
+])
+def test_value_inserter(sp):
+    ind = tuple([randint(0, ss - 1) for ss in sp.shape])
+    op = ift.ValueInserter(sp, ind)
+    f = ift.from_random('normal', ift.UnstructuredDomain((1,)))
+    inp = f.to_global_data()
+    ret = op(f).to_global_data()
+    assert_(ret[ind] == inp)
+    assert_(np.sum(ret) == inp)
-- 
GitLab