domain_tuple_field_inserter.py 2.48 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
14
# Copyright(C) 2013-2019 Max-Planck-Society
Martin Reinecke's avatar
Martin Reinecke committed
15
#
16
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
17

Philipp Arras's avatar
Fixups  
Philipp Arras committed
18
import numpy as np
19
20

from ..compat import *
Philipp Arras's avatar
Fixups  
Philipp Arras committed
21
from ..domain_tuple import DomainTuple
22
from ..field import Field
Philipp Arras's avatar
Fixups  
Philipp Arras committed
23
24
25
from .linear_operator import LinearOperator


Philipp Arras's avatar
Philipp Arras committed
26
class DomainTupleFieldInserter(LinearOperator):
    '''Writes the content of a field into one slice of a DomainTuple.

    The operator's domain is `domain`; its target is `domain` with the
    sub-domain(s) of `new_space` inserted at `index`. The forward application
    places the input field at the multi-index `position` of `new_space` and
    fills every other entry of the target with zeros. The adjoint extracts
    that same slice again.
    '''

    def __init__(self, domain, new_space, index, position):
        '''Writes the content of a field into one slice of a DomainTuple.

        Parameters
        ----------
        domain : Domain, tuple of Domain or DomainTuple
            Domain of the input field.
        new_space : Domain, tuple of Domain or DomainTuple
            Space(s) inserted into `domain` to form the target.
        index : Integer
            Index at which new_space shall be added to domain.
        position : sequence of int
            Multi-index into new_space selecting the slice that the input
            field is written into. Needs exactly one entry per array axis of
            new_space, each in ``[0, extent)``.

        Raises
        ------
        ValueError
            If `position` has the wrong length or an entry is out of range.
        TypeError
            If an entry of `position` is not an integer.
        '''
        self._domain = DomainTuple.make(domain)
        # Normalize exactly like `domain`, so that all documented input types
        # (single Domain, tuple of Domains, DomainTuple) are accepted.
        new_space = DomainTuple.make(new_space)
        tgt = list(self.domain)
        # Splice in every sub-domain of new_space. For a single Domain this is
        # equivalent to tgt.insert(index, new_space), including the handling
        # of negative and out-of-range indices.
        tgt[index:index] = list(new_space)
        self._target = DomainTuple.make(tgt)
        self._capability = self.TIMES | self.ADJOINT_TIMES

        # Number of array axes in front of the inserted space; these axes are
        # taken whole. Axes after it are implicitly taken whole by numpy.
        fst_dims = sum(len(dd.shape) for dd in self.domain[:index])
        nshp = new_space.shape
        # Accept lists/arrays as well as tuples; the index below is built by
        # tuple concatenation.
        position = tuple(position)
        if len(position) != len(nshp):
            raise ValueError("shape mismatch between new_space and position")
        for s, p in zip(nshp, position):
            if not isinstance(p, (int, np.integer)):
                raise TypeError("position entries must be integers")
            if p < 0 or p >= s:
                raise ValueError("bad position value")
        self._slc = (slice(None),)*fst_dims + position

    def apply(self, x, mode):
        '''Embed `x` into the target (TIMES) or extract the slice (ADJOINT_TIMES).

        TIMES builds a zero array of the target's shape and writes `x` into
        the slice selected at construction. ADJOINT_TIMES returns that slice
        of `x` as a field on the operator's domain.
        '''
        self._check_input(x, mode)
        # FIXME Make fully MPI compatible without global_data
        if mode == self.TIMES:
            res = np.zeros(self.target.shape, dtype=x.dtype)
            res[self._slc] = x.to_global_data()
            return Field.from_global_data(self.target, res)
        else:
            return Field.from_global_data(self.domain,
                                          x.to_global_data()[self._slc])