contraction_operator.py 2.49 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
from __future__ import absolute_import, division, print_function
20

Philipp Arras's avatar
Fixups  
Philipp Arras committed
21
import numpy as np
22

Philipp Arras's avatar
Philipp Arras committed
23
from .. import utilities
24
from ..compat import *
Philipp Arras's avatar
Fixups  
Philipp Arras committed
25
from ..domain_tuple import DomainTuple
26
from ..field import Field
Philipp Arras's avatar
Fixups  
Philipp Arras committed
27
28
29
from .linear_operator import LinearOperator


30
31
class ContractionOperator(LinearOperator):
    """A linear operator which sums up fields into the direction of subspaces.

    This ContractionOperator sums up a field which is defined on a DomainTuple
    over all of its sub-domains that are *not* listed in `spaces`. The result
    lives on a DomainTuple made of the remaining sub-domains, i.e. on a subset
    of the operator's domain. The adjoint broadcasts a field defined on that
    subset back onto the full domain, repeating it along every sub-domain
    that the forward application sums over.

    Parameters
    ----------
    domain : Domain, tuple of Domain or DomainTuple
        The domain of the operator, i.e. where the input field lives.
    spaces : int or tuple of int
        The elements of "domain" which are taken as target. In the target
        they appear in the same order as in "domain".
    """

    def __init__(self, domain, spaces):
        self._domain = DomainTuple.make(domain)
        self._spaces = utilities.parse_spaces(spaces, len(self._domain))
        # Keep only the selected sub-domains. Enumerating `self._domain`
        # (rather than iterating `spaces`) preserves the domain's order.
        self._target = DomainTuple.make(
            [dom for i, dom in enumerate(self._domain) if i in self._spaces])
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        """Apply the contraction (TIMES) or its adjoint (ADJOINT_TIMES).

        Parameters
        ----------
        x : Field
            Input field. It lives on the domain for TIMES and on the target
            for ADJOINT_TIMES.
        mode : int
            Either `self.TIMES` or `self.ADJOINT_TIMES`.

        Returns
        -------
        Field
            For TIMES: `x` summed over all sub-domains not in `spaces`.
            For ADJOINT_TIMES: `x` broadcast onto the full domain.
        """
        self._check_input(x, mode)
        if mode == self.ADJOINT_TIMES:
            # NOTE(review): `local_data` / `local_shape` suggest the data may
            # be split along the first array axis; presumably the local chunk
            # is only usable if sub-domain 0 is part of the target (so the
            # input and output are split alike). Otherwise the full input is
            # gathered first. Confirm against the data-object backend.
            ldat = x.local_data if 0 in self._spaces else x.to_global_data()
            # Build a shape in which every summed-over sub-domain collapses
            # to length-1 axes, so numpy can broadcast along them.
            shp = []
            for i, dom in enumerate(self._domain):
                if i not in self._spaces:
                    shp += (1,)*len(dom.shape)
                elif i == 0:
                    # The first axis is the one that may be local-only.
                    shp += dom.local_shape
                else:
                    shp += dom.shape
            ldat = np.broadcast_to(ldat.reshape(shp), self._domain.local_shape)
            return Field.from_local_data(self._domain, ldat)
        # Forward: sum over every sub-domain that is not part of the target.
        return x.sum(
            [s for s in range(len(x.domain)) if s not in self._spaces])