From fbe00860d4ef03b70d01f32d54c723dc2204da02 Mon Sep 17 00:00:00 2001
From: Philipp Arras <parras@mpa-garching.mpg.de>
Date: Thu, 4 Oct 2018 10:34:23 +0200
Subject: [PATCH] Add ValueInserter

---
 nifty5/__init__.py                          |  2 +-
 nifty5/operators/simple_linear_operators.py | 36 ++++++++++++++++++++-
 test/test_operators/test_adjoint.py         |  8 +++++
 3 files changed, 44 insertions(+), 2 deletions(-)

diff --git a/nifty5/__init__.py b/nifty5/__init__.py
index 153e00464..df1e4b47b 100644
--- a/nifty5/__init__.py
+++ b/nifty5/__init__.py
@@ -46,7 +46,7 @@ from .operators.block_diagonal_operator import BlockDiagonalOperator
 from .operators.outer_product_operator import OuterProduct
 from .operators.simple_linear_operators import (
     VdotOperator, ConjugationOperator, Realizer,
-    FieldAdapter, GeometryRemover, NullOperator)
+    FieldAdapter, GeometryRemover, NullOperator, ValueInserter)
 from .operators.energy_operators import (
     EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
     BernoulliEnergy, Hamiltonian, SampledKullbachLeiblerDivergence)
diff --git a/nifty5/operators/simple_linear_operators.py b/nifty5/operators/simple_linear_operators.py
index 1592ce024..1bce95fc8 100644
--- a/nifty5/operators/simple_linear_operators.py
+++ b/nifty5/operators/simple_linear_operators.py
@@ -18,13 +18,14 @@
 
 from __future__ import absolute_import, division, print_function
 
+import numpy as np
+
 from ..compat import *
 from ..domain_tuple import DomainTuple
 from ..domains.unstructured_domain import UnstructuredDomain
 from ..field import Field
 from ..multi_domain import MultiDomain
 from ..multi_field import MultiField
-from ..sugar import full
 from .endomorphic_operator import EndomorphicOperator
 from .linear_operator import LinearOperator
 
@@ -141,3 +142,36 @@ class NullOperator(LinearOperator):
     def apply(self, x, mode):
         self._check_input(x, mode)
         return self._nullfield(self._tgt(mode))
+
+
+class ValueInserter(LinearOperator):
+    """Operator which inserts one value into a field.
+
+    Parameters
+    ----------
+    target : Domain, tuple of Domain or DomainTuple
+    index : tuple
+        The index of the target into which the value of the domain shall be
+        written.
+    """
+
+    def __init__(self, target, index):
+        from ..sugar import makeDomain
+        self._domain = makeDomain(UnstructuredDomain(1))
+        self._target = DomainTuple.make(target)
+        if not isinstance(index, tuple):
+            raise TypeError
+        self._index = index
+        self._capability = self.TIMES | self.ADJOINT_TIMES
+        # Check if index is in bounds
+        np.empty(self.target.shape)[self._index]
+
+    def apply(self, x, mode):
+        self._check_input(x, mode)
+        x = x.to_global_data()
+        if mode == self.TIMES:
+            res = np.zeros(self.target.shape, dtype=x.dtype)
+            res[self._index] = x
+        else:
+            res = x[self._index]
+        return Field.from_global_data(self._tgt(mode), res)
diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py
index 32597675c..e92e2fa1d 100644
--- a/test/test_operators/test_adjoint.py
+++ b/test/test_operators/test_adjoint.py
@@ -44,6 +44,14 @@ class Consistency_Tests(unittest.TestCase):
         op = ift.LOSResponse(sp, starts, ends, sigma_low, sigma_ups)
         ift.extra.consistency_check(op, dtype, dtype)
 
+    def testValueInserter(self):
+        op = ift.ValueInserter(ift.RGSpace([23, 44]), (2, 43))
+        ift.extra.consistency_check(op)
+        self.assertRaises(IndexError,
+                          lambda: ift.ValueInserter(ift.RGSpace(3), (7,)))
+        self.assertRaises(TypeError,
+                          lambda: ift.ValueInserter(ift.RGSpace(3), 2))
+
     @expand(product(_h_spaces + _p_spaces + _pow_spaces,
                     [np.float64, np.complex128]))
     def testOperatorCombinations(self, sp, dtype):
-- 
GitLab