domain_tuple_field_inserter.py 2.45 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
17

Philipp Arras's avatar
Fixups    
Philipp Arras committed
18
import numpy as np
19

Philipp Arras's avatar
Fixups    
Philipp Arras committed
20
from ..domain_tuple import DomainTuple
21
from ..field import Field
Philipp Arras's avatar
Fixups    
Philipp Arras committed
22
23
24
from .linear_operator import LinearOperator


Philipp Arras's avatar
Philipp Arras committed
25
class DomainTupleFieldInserter(LinearOperator):
    """Writes the content of a :class:`Field` into one slice of a
    :class:`DomainTuple`.

    The operator's domain is `target` with the sub-domain number `space`
    removed. Applying the operator yields a field on `target` that is zero
    everywhere except at position `pos` of the inserted sub-domain, where it
    holds the input field. The adjoint extracts that slice again.

    Parameters
    ----------
    target : Domain, tuple of Domain or DomainTuple
        The full domain into which the input field is embedded.
    space : int
        The index of the sub-domain which is inserted.
    pos : tuple of int
        Position in the new sub-domain at which the input field shall be
        written. Must contain one non-negative index per axis of the
        sub-domain, each smaller than the extent of that axis.

    Raises
    ------
    ValueError
        If `space` is not a valid sub-domain index of `target`, if the
        length of `pos` differs from the number of axes of the inserted
        sub-domain, or if an entry of `pos` is out of bounds.
    """

    def __init__(self, target, space, pos):
        # Normalize first so that the checks below operate on a DomainTuple
        # regardless of which of the accepted forms `target` was passed in.
        self._target = DomainTuple.make(target)
        if not 0 <= space < len(self._target):
            raise ValueError("space index out of range")

        subdomains = list(self._target)
        # Number of array axes in front of the inserted sub-domain.
        fst_dims = sum(len(dd.shape) for dd in subdomains[:space])
        new_space = subdomains.pop(space)
        self._domain = DomainTuple.make(subdomains)
        self._capability = self.TIMES | self.ADJOINT_TIMES

        nshp = new_space.shape
        # Accept any sequence of indices; the slice built below must be a
        # tuple to be concatenated and used for array indexing.
        pos = tuple(pos)
        if len(pos) != len(nshp):
            raise ValueError("shape mismatch between new_space and position")
        for s, p in zip(nshp, pos):
            if p < 0 or p >= s:
                raise ValueError("bad position value")

        # Full slices over all leading axes, fixed indices on the axes of the
        # inserted sub-domain; trailing axes are implicitly taken in full.
        self._slc = (slice(None),)*fst_dims + pos

    def apply(self, x, mode):
        """Apply the operator (`TIMES`) or its adjoint (`ADJOINT_TIMES`).

        For `TIMES`, `x` is written into the stored slice of a zero-filled
        array on the target. For any other accepted mode the stored slice is
        read out of `x` and returned as a field on the domain.
        """
        self._check_input(x, mode)
        # FIXME Make fully MPI compatible without global_data
        if mode == self.TIMES:
            res = np.zeros(self.target.shape, dtype=x.dtype)
            res[self._slc] = x.to_global_data()
            return Field.from_global_data(self.target, res)
        else:
            return Field.from_global_data(self.domain,
                                          x.to_global_data()[self._slc])