Commit fbe00860 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add ValueInserter

parent 7d9c0432
......@@ -46,7 +46,7 @@ from .operators.block_diagonal_operator import BlockDiagonalOperator
from .operators.outer_product_operator import OuterProduct
from .operators.simple_linear_operators import (
VdotOperator, ConjugationOperator, Realizer,
FieldAdapter, GeometryRemover, NullOperator)
FieldAdapter, GeometryRemover, NullOperator, ValueInserter)
from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
BernoulliEnergy, Hamiltonian, SampledKullbachLeiblerDivergence)
......@@ -18,13 +18,14 @@
from __future__ import absolute_import, division, print_function
import numpy as np
from ..compat import *
from ..domain_tuple import DomainTuple
from import UnstructuredDomain
from ..field import Field
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..sugar import full
from .endomorphic_operator import EndomorphicOperator
from .linear_operator import LinearOperator
......@@ -141,3 +142,36 @@ class NullOperator(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
return self._nullfield(self._tgt(mode))
class ValueInserter(LinearOperator):
"""Operator which inserts one value into a field.
target : Domain, tuple of Domain or DomainTuple
index : tuple
The index of the target into which the value of the domain shall be
def __init__(self, target, index):
from ..sugar import makeDomain
self._domain = makeDomain(UnstructuredDomain(1))
self._target = DomainTuple.make(target)
if not isinstance(index, tuple):
raise TypeError
self._index = index
self._capability = self.TIMES | self.ADJOINT_TIMES
# Check if index is in bounds
def apply(self, x, mode):
self._check_input(x, mode)
x = x.to_global_data()
if mode == self.TIMES:
res = np.zeros(, dtype=x.dtype)
res[self._index] = x
res = x[self._index]
return Field.from_global_data(self._tgt(mode), res)
......@@ -44,6 +44,14 @@ class Consistency_Tests(unittest.TestCase):
op = ift.LOSResponse(sp, starts, ends, sigma_low, sigma_ups)
ift.extra.consistency_check(op, dtype, dtype)
def testValueInserter(self):
op = ift.ValueInserter(ift.RGSpace([23, 44]), (2, 43))
lambda: ift.ValueInserter(ift.RGSpace(3), (7,)))
lambda: ift.ValueInserter(ift.RGSpace(3), 2))
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
[np.float64, np.complex128]))
def testOperatorCombinations(self, sp, dtype):
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment