power_projection_operator.py 3.21 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

19
from .. import Field, DomainTuple
20
21
22
23
24
from ..spaces import PowerSpace
from .linear_operator import LinearOperator
from .. import dobj
import numpy as np

25

26
class PowerProjectionOperator(LinearOperator):
    """Project a field on a harmonic space onto the bins of a PowerSpace.

    The operator replaces the harmonic space at index `space` of its domain
    by a PowerSpace; all other spaces of the domain are kept unchanged.

    Forward application: the input is weighted with ``weight(1)``, every
    harmonic pixel is accumulated into the power bin given by the power
    space's ``pindex``, and the result is weighted with ``weight(-1)`` on
    the power space.

    Adjoint application: every harmonic pixel receives the value of the
    bin it belongs to (according to ``pindex``); no weighting is applied.

    Parameters
    ----------
    domain : object accepted by DomainTuple.make
        The domain of the operator.
    power_space : PowerSpace, optional
        The power space replacing ``domain[space]`` in the target. If None,
        ``PowerSpace(domain[space])`` is used; otherwise its
        ``harmonic_partner`` must equal ``domain[space]``.
    space : int, optional
        Index of the harmonic space inside `domain`. May only be omitted
        when `domain` consists of exactly one space.

    Raises
    ------
    TypeError
        If `space` is omitted for a domain that does not have exactly one
        space, or if `power_space` is not a PowerSpace.
    ValueError
        If `space` is out of range, ``domain[space]`` is not harmonic, or
        `power_space` does not belong to ``domain[space]``.
    """

    def __init__(self, domain, power_space=None, space=None):
        super(PowerProjectionOperator, self).__init__()

        # Initialize domain and target
        self._domain = DomainTuple.make(domain)
        if space is None:
            # Omitting `space` is only unambiguous for single-space domains;
            # fail with a clear message instead of int(None)'s TypeError.
            if len(self._domain) != 1:
                raise TypeError("space must be specified for a domain "
                                "that does not consist of exactly one space")
            space = 0
        space = int(space)
        if space < 0 or space >= len(self._domain):
            raise ValueError("space index out of range")
        hspace = self._domain[space]
        if not hspace.harmonic:
            raise ValueError("Operator acts on harmonic spaces only")
        if power_space is None:
            power_space = PowerSpace(hspace)
        else:
            if not isinstance(power_space, PowerSpace):
                raise TypeError("power_space argument must be a PowerSpace")
            if power_space.harmonic_partner != hspace:
                raise ValueError("power_space does not match its partner")

        self._space = space
        tgt = list(self._domain)
        tgt[self._space] = power_space
        self._target = DomainTuple.make(tgt)

    def _flat_pindex(self):
        """Return the target power space's `pindex` as a flat ndarray.

        Entry i is used as the bin index of (flattened) harmonic pixel i.
        """
        return dobj.to_ndarray(self._target[self._space].pindex).ravel()

    def _times(self, x):
        pindex = self._flat_pindex()
        # The collapsed shape is assumed to be (size before, size of the
        # space, size after), matching the 3-element index used below.
        arr = x.weight(1).val.reshape(
            x.domain.collapsed_shape_for_domain(self._space))
        out = dobj.to_ndarray(dobj.zeros(
            self._target.collapsed_shape_for_domain(self._space),
            dtype=x.dtype))
        # np.add.at is unbuffered: pixels sharing a bin are all summed,
        # whereas fancy-indexed `+=` would keep only one contribution.
        np.add.at(out, (slice(None), pindex, slice(None)),
                  dobj.to_ndarray(arr))
        out = dobj.from_ndarray(out)
        return Field(self._target, out.reshape(self._target.shape))\
            .weight(-1, spaces=self._space)

    def _adjoint_times(self, x):
        pindex = self._flat_pindex()
        # Convert to a plain ndarray before numpy fancy indexing, exactly
        # as _times does, and convert back before building the Field.
        arr = dobj.to_ndarray(
            x.val.reshape(x.domain.collapsed_shape_for_domain(self._space)))
        # Each harmonic pixel picks up the value of the bin it belongs to.
        out = dobj.from_ndarray(arr[:, pindex, :])
        return Field(self._domain, out.reshape(self._domain.shape))

    @property
    def domain(self):
        return self._domain

    @property
    def target(self):
        return self._target

    @property
    def unitary(self):
        return False