# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
20
from __future__ import division
from builtins import range
21
22
import numpy as np

Martin Reinecke's avatar
Martin Reinecke committed
23
from ...field import Field
Martin Reinecke's avatar
Martin Reinecke committed
24
from ...domain_tuple import DomainTuple
Martin Reinecke's avatar
Martin Reinecke committed
25
from ..endomorphic_operator import EndomorphicOperator
26
from ...nifty_utilities import cast_iseq_to_tuple
27
28

class DiagonalOperator(EndomorphicOperator):
    """ NIFTY class for diagonal operators.

    The NIFTY DiagonalOperator class is a subclass derived from the
    EndomorphicOperator. It multiplies an input field pixel-wise with its
    diagonal.

    Parameters
    ----------
    diagonal : Field
        The diagonal entries of the operator.
    domain : tuple of DomainObjects, i.e. Spaces and FieldTypes, optional
        The domain of the operator. If None (default), the domain of
        `diagonal` is used; otherwise the value is passed through
        `DomainTuple.make`.
    spaces : int or tuple of int, optional
        Indices of the sub-domains of `domain` on which `diagonal` lives.
        The i-th sub-domain of `diagonal` must equal the sub-domain of
        `domain` with index `spaces[i]`. If None (default), the domain of
        `diagonal` must equal the full operator domain.
    copy : boolean
        Internal copy of the diagonal (default: True)

    Attributes
    ----------
    domain : tuple of DomainObjects, i.e. Spaces and FieldTypes
        The domain on which the Operator's input Field lives.
    target : tuple of DomainObjects, i.e. Spaces and FieldTypes
        The domain in which the outcome of the operator lives. As the Operator
        is endomorphic this is the same as its domain.
    unitary : boolean
        Indicates whether the Operator is unitary or not.
    self_adjoint : boolean
        Indicates whether the operator is self_adjoint or not.

    Raises
    ------
    TypeError
        If `diagonal` is not a Field.
    ValueError
        If the domain of `diagonal` does not match the operator domain (or
        the sub-domains selected by `spaces`), if `spaces` has a different
        length than the domain of `diagonal`, if it has more entries than
        the operator domain, or if it contains repeated entries.

    See Also
    --------
    EndomorphicOperator

    """

    # ---Overwritten properties and methods---

    def __init__(self, diagonal, domain=None, spaces=None, copy=True):
        super(DiagonalOperator, self).__init__()

        if not isinstance(diagonal, Field):
            raise TypeError("Field object required")
        if domain is None:
            self._domain = diagonal.domain
        else:
            self._domain = DomainTuple.make(domain)
        if spaces is None:
            # the diagonal has to cover the complete operator domain
            self._spaces = None
            if diagonal.domain != self._domain:
                raise ValueError("domain mismatch")
        else:
            self._spaces = cast_iseq_to_tuple(spaces)
            nspc = len(self._spaces)
            if nspc != len(diagonal.domain.domains):
                raise ValueError("spaces and domain must have the same length")
            if nspc > len(self._domain.domains):
                raise ValueError("too many spaces")
            if nspc > len(set(self._spaces)):
                raise ValueError("non-unique space indices")
            # TODO: if nspc == len(self._domain.domains) (presumably what the
            # original note meant), some optimization would be possible
            for i, j in enumerate(self._spaces):
                if diagonal.domain[i] != self._domain[j]:
                    raise ValueError("domain mismatch")

        self._diagonal = diagonal if not copy else diagonal.copy()
        # both properties are computed lazily on first access and then cached
        self._self_adjoint = None
        self._unitary = None

    def _times(self, x):
        """ Returns `diagonal * x`. """
        return self._times_helper(x, lambda z: z.__mul__)

    def _adjoint_times(self, x):
        """ Returns `conjugate(diagonal) * x`. """
        return self._times_helper(x, lambda z: z.conjugate().__mul__)

    def _inverse_times(self, x):
        """ Returns `x / diagonal`. """
        # z.__rtruediv__(x) evaluates x / z
        return self._times_helper(x, lambda z: z.__rtruediv__)

    def _adjoint_inverse_times(self, x):
        """ Returns `x / conjugate(diagonal)`. """
        return self._times_helper(x, lambda z: z.conjugate().__rtruediv__)

    def diagonal(self, copy=True):
        """ Returns the diagonal of the Operator.

        Parameters
        ----------
        copy : boolean
            Whether the returned Field should be copied or not.

        Returns
        -------
        out : Field
            The diagonal of the Operator.

        """
        return self._diagonal.copy() if copy else self._diagonal

    # ---Mandatory properties and methods---

    @property
    def domain(self):
        return self._domain

    @property
    def self_adjoint(self):
        if self._self_adjoint is None:
            if not issubclass(self._diagonal.dtype.type, np.complexfloating):
                # a non-complex diagonal equals its own conjugate
                self._self_adjoint = True
            else:
                self._self_adjoint = (self._diagonal.val.imag == 0).all()
        return self._self_adjoint

    @property
    def unitary(self):
        if self._unitary is None:
            # exact comparison, no tolerance is applied
            self._unitary = (abs(self._diagonal.val) == 1.).all()
        return self._unitary

    # ---Added properties and methods---

    def _times_helper(self, x, operation):
        """ Applies the diagonal to `x` through `operation`.

        Parameters
        ----------
        x : Field
            The input Field.
        operation : callable
            Receives the diagonal (a Field if `spaces` was None, otherwise
            an array broadcastable against `x.val`) and returns a callable
            which is then applied to `x` or `x.val`, respectively.

        Returns
        -------
        out : Field
            The result of applying the operation.

        """
        if self._spaces is None:
            return operation(self._diagonal)(x)

        # Pair every axis of x the diagonal acts on with the diagonal axis
        # belonging to it. Sorting by the axis of x yields the order in which
        # the diagonal's axes have to appear for broadcasting.
        diag_axes = self._diagonal.domain.axes
        axis_pairs = []
        for diag_index, space_index in enumerate(self._spaces):
            axis_pairs.extend(zip(x.domain.axes[space_index],
                                  diag_axes[diag_index]))
        axis_pairs.sort()
        active_axes = set(pair[0] for pair in axis_pairs)
        diag_order = [pair[1] for pair in axis_pairs]

        local_diagonal = self._diagonal.val
        if diag_order != list(range(len(diag_order))):
            # `spaces` was not given in ascending order: a plain reshape
            # would only reinterpret the memory layout and pair the wrong
            # entries with x, so the axes are reordered first.
            local_diagonal = np.transpose(local_diagonal, diag_order)

        # length-1 axes wherever the diagonal does not act, so that numpy
        # broadcasting extends it over the remaining axes of x
        reshaper = [extent if axis in active_axes else 1
                    for axis, extent in enumerate(x.shape)]
        reshaped_local_diagonal = np.reshape(local_diagonal, reshaper)

        # here the actual multiplication takes place
        return Field(x.domain, val=operation(reshaped_local_diagonal)(x.val))