From fbe00860d4ef03b70d01f32d54c723dc2204da02 Mon Sep 17 00:00:00 2001 From: Philipp Arras <parras@mpa-garching.mpg.de> Date: Thu, 4 Oct 2018 10:34:23 +0200 Subject: [PATCH] Add ValueInserter --- nifty5/__init__.py | 2 +- nifty5/operators/simple_linear_operators.py | 36 ++++++++++++++++++++- test/test_operators/test_adjoint.py | 8 +++++ 3 files changed, 44 insertions(+), 2 deletions(-) diff --git a/nifty5/__init__.py b/nifty5/__init__.py index 153e00464..df1e4b47b 100644 --- a/nifty5/__init__.py +++ b/nifty5/__init__.py @@ -46,7 +46,7 @@ from .operators.block_diagonal_operator import BlockDiagonalOperator from .operators.outer_product_operator import OuterProduct from .operators.simple_linear_operators import ( VdotOperator, ConjugationOperator, Realizer, - FieldAdapter, GeometryRemover, NullOperator) + FieldAdapter, GeometryRemover, NullOperator, ValueInserter) from .operators.energy_operators import ( EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood, BernoulliEnergy, Hamiltonian, SampledKullbachLeiblerDivergence) diff --git a/nifty5/operators/simple_linear_operators.py b/nifty5/operators/simple_linear_operators.py index 1592ce024..1bce95fc8 100644 --- a/nifty5/operators/simple_linear_operators.py +++ b/nifty5/operators/simple_linear_operators.py @@ -18,13 +18,14 @@ from __future__ import absolute_import, division, print_function +import numpy as np + from ..compat import * from ..domain_tuple import DomainTuple from ..domains.unstructured_domain import UnstructuredDomain from ..field import Field from ..multi_domain import MultiDomain from ..multi_field import MultiField -from ..sugar import full from .endomorphic_operator import EndomorphicOperator from .linear_operator import LinearOperator @@ -141,3 +142,36 @@ class NullOperator(LinearOperator): def apply(self, x, mode): self._check_input(x, mode) return self._nullfield(self._tgt(mode)) + + +class ValueInserter(LinearOperator): + """Operator which inserts one value into a field. + + Parameters + ---------- + target : Domain, tuple of Domain or DomainTuple + index : tuple + The index of the target into which the value of the domain shall be + written. + """ + + def __init__(self, target, index): + from ..sugar import makeDomain + self._domain = makeDomain(UnstructuredDomain(1)) + self._target = DomainTuple.make(target) + if not isinstance(index, tuple): + raise TypeError + self._index = index + self._capability = self.TIMES | self.ADJOINT_TIMES + # Check if index is in bounds + np.empty(self.target.shape)[self._index] + + def apply(self, x, mode): + self._check_input(x, mode) + x = x.to_global_data() + if mode == self.TIMES: + res = np.zeros(self.target.shape, dtype=x.dtype) + res[self._index] = x + else: + res = x[self._index] + return Field.from_global_data(self._tgt(mode), res) diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py index 32597675c..e92e2fa1d 100644 --- a/test/test_operators/test_adjoint.py +++ b/test/test_operators/test_adjoint.py @@ -44,6 +44,14 @@ class Consistency_Tests(unittest.TestCase): op = ift.LOSResponse(sp, starts, ends, sigma_low, sigma_ups) ift.extra.consistency_check(op, dtype, dtype) + def testValueInserter(self): + op = ift.ValueInserter(ift.RGSpace([23, 44]), (2, 43)) + ift.extra.consistency_check(op) + self.assertRaises(IndexError, + lambda: ift.ValueInserter(ift.RGSpace(3), (7,))) + self.assertRaises(TypeError, + lambda: ift.ValueInserter(ift.RGSpace(3), 2)) + @expand(product(_h_spaces + _p_spaces + _pow_spaces, [np.float64, np.complex128])) def testOperatorCombinations(self, sp, dtype): -- GitLab