value_inserter.py 2.21 KB
Newer Older
Philipp Arras's avatar
Philipp Arras committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..sugar import makeDomain
from .linear_operator import LinearOperator


class ValueInserter(LinearOperator):
    """Inserts one value into a field which is zero otherwise.

    The domain of this operator is `UnstructuredDomain(1)`, i.e. it holds a
    single value. In `TIMES` mode that value is written to position `index`
    of an otherwise zero field on `target`. In `ADJOINT_TIMES` mode the entry
    at `index` of the input field is read out and returned as a field on the
    one-element domain.

    Parameters
    ----------
    target : Domain, tuple of Domain or DomainTuple
    index : iterable of int
        The index of the target into which the value of the domain shall be
        inserted. It must have exactly one entry per axis of `target`, and
        each entry must be a non-negative integer (builtin `int` or a numpy
        integer) smaller than the extent of the corresponding axis.

    Raises
    ------
    ValueError
        If `index` does not have one entry per axis of `target`.
    TypeError
        If an entry of `index` is not an integer or lies outside the
        corresponding axis of `target`.
    """

    def __init__(self, target, index):
        self._domain = makeDomain(UnstructuredDomain(1))
        self._target = DomainTuple.make(target)
        index = tuple(index)
        shape = self.target.shape

        # Check the length first. Otherwise an over-long index would fail
        # with an uncaught IndexError on `shape[i]` instead of this error.
        if len(index) != len(shape):
            raise ValueError(
                "index has {} entries but target has {} axes".format(
                    len(index), len(shape)))

        for axis, (n, extent) in enumerate(zip(index, shape)):
            # Negative (wrap-around) indices are deliberately rejected.
            if not isinstance(n, (int, np.integer)) or not 0 <= n < extent:
                raise TypeError(
                    "index entry {!r} for axis {} must be an integer in "
                    "[0, {})".format(n, axis, extent))

        # Normalise numpy integer scalars to builtin ints. The bounds were
        # verified above, so no extra indexing probe is needed.
        self._index = tuple(int(n) for n in index)
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        """Apply the operator (`TIMES`) or its adjoint (`ADJOINT_TIMES`).

        `TIMES`: embed the single input value at `self._index` in a zero field
        on the target. `ADJOINT_TIMES`: extract the value at `self._index`
        from a field on the target into a one-element field on the domain.
        """
        self._check_input(x, mode)
        x = x.to_global_data()
        if mode == self.TIMES:
            res = np.zeros(self.target.shape, dtype=x.dtype)
            # The domain is UnstructuredDomain(1), so x holds a single value.
            # Take it explicitly instead of relying on NumPy converting a
            # size-1 array to a scalar during element assignment.
            res[self._index] = x[0]
        else:
            res = np.full((1,), x[self._index], dtype=x.dtype)
        return Field.from_global_data(self._tgt(mode), res)