# domain_distributor.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from __future__ import absolute_import, division, print_function

import numpy as np

from .. import utilities
from ..compat import *
from ..domain_tuple import DomainTuple
from ..field import Field
from .linear_operator import LinearOperator


class DomainDistributor(LinearOperator):
    """Replicate a field over extra spaces of a target domain, or sum them out.

    The operator's target is ``DomainTuple.make(target)``. Its domain is the
    ``DomainTuple`` built from only those target entries whose indices appear
    in ``spaces``.

    * ``TIMES`` broadcasts (copies) the input along every target space that is
      *not* listed in ``spaces``.
    * ``ADJOINT_TIMES`` sums the input over every space that is *not* listed
      in ``spaces``.

    Parameters
    ----------
    target :
        Anything accepted by ``DomainTuple.make``; the full output domain.
    spaces :
        Index or indices of the ``target`` entries that form the operator's
        domain. Normalized via ``utilities.parse_spaces`` (presumably an int
        or a sequence of ints -- see that helper).
    """

    def __init__(self, target, spaces):
        self._target = DomainTuple.make(target)
        self._spaces = utilities.parse_spaces(spaces, len(self._target))
        # The domain keeps only the selected target spaces, in target order.
        self._domain = DomainTuple.make(
            [dom for idx, dom in enumerate(self._target)
             if idx in self._spaces])
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        """Apply the operator (``TIMES``) or its adjoint to the field ``x``.

        ``TIMES`` returns a field on the target in which the values of ``x``
        are repeated along all spaces not in ``spaces``. The adjoint returns
        ``x`` summed over exactly those spaces.
        """
        self._check_input(x, mode)
        if mode != self.TIMES:
            # ADJOINT_TIMES: x lives on the full target, so its space indices
            # coincide with target indices; contract every non-selected one.
            return x.sum([s for s in range(len(x.domain))
                          if s not in self._spaces])

        # NOTE(review): only the first target space uses ``local_shape``
        # below, which suggests the data is split across tasks along the
        # first space only. If that space belongs to the domain, x's local
        # chunk already lines up with the target's local chunk; otherwise
        # x's full (global) array is needed -- confirm against dobj docs.
        ldat = x.local_data if 0 in self._spaces else x.to_global_data()

        # Keep the axes of selected spaces and insert length-1 axes for every
        # other target space, so np.broadcast_to can replicate along them.
        shp = []
        for idx, dom in enumerate(self._target):
            axes = dom.shape if idx > 0 else dom.local_shape
            shp += axes if idx in self._spaces else (1,)*len(dom.shape)

        # np.broadcast_to returns a read-only view; presumably
        # Field.from_local_data accepts (or copies) such arrays -- confirm.
        ldat = np.broadcast_to(ldat.reshape(shp), self._target.local_shape)
        return Field.from_local_data(self._target, ldat)