diagonal_operator.py 6.03 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
Theo Steininger's avatar
Theo Steininger committed
13
14
15
16
17
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
18

Martin Reinecke's avatar
Martin Reinecke committed
19
20
from __future__ import division
from builtins import range
21
22
import numpy as np

Martin Reinecke's avatar
Martin Reinecke committed
23
from ...field import Field
Martin Reinecke's avatar
Martin Reinecke committed
24
from ...domain_tuple import DomainTuple
Martin Reinecke's avatar
Martin Reinecke committed
25
from ..endomorphic_operator import EndomorphicOperator
26
27
28


class DiagonalOperator(EndomorphicOperator):
    """ NIFTY class for diagonal operators.

    The NIFTY DiagonalOperator class is a subclass derived from the
    EndomorphicOperator. It multiplies an input field pixel-wise with its
    diagonal.

    Parameters
    ----------
    domain : tuple of DomainObjects, i.e. Spaces and FieldTypes
        The domain on which the Operator's input Field lives. It is passed
        through `DomainTuple.make` (default: ()).
    diagonal : {scalar, list, array, Field}
        The diagonal entries of the operator. The value is converted via
        `Field(domain=domain, val=diagonal, copy=copy)` (default: None).
    copy : boolean
        Internal copy of the diagonal (default: True)
    default_spaces : tuple of ints *optional*
        Defines on which space(s) of a given field the Operator acts by
        default (default: None)

    Attributes
    ----------
    domain : tuple of DomainObjects, i.e. Spaces and FieldTypes
        The domain on which the Operator's input Field lives.
    target : tuple of DomainObjects, i.e. Spaces and FieldTypes
        The domain in which the outcome of the operator lives. As the Operator
        is endomorphic this is the same as its domain.
    unitary : boolean
        True if every diagonal entry has absolute value exactly 1.
    self_adjoint : boolean
        True if the diagonal has a non-complex dtype, or if all of its
        imaginary parts are exactly 0.

    Notes
    -----
    `unitary` and `self_adjoint` are computed lazily and cached. The caches
    are reset only by `set_diagonal`. Mutating the Field returned by
    `diagonal(copy=False)` in place therefore leaves them stale.

    See Also
    --------
    EndomorphicOperator

    """

    # ---Overwritten properties and methods---

    def __init__(self, domain=(), diagonal=None, copy=True,
                 default_spaces=None):
        super(DiagonalOperator, self).__init__(default_spaces)

        self._domain = DomainTuple.make(domain)

        # Lazily computed property caches; set_diagonal() resets them again.
        self._self_adjoint = None
        self._unitary = None
        self.set_diagonal(diagonal=diagonal, copy=copy)

    def _times(self, x, spaces):
        # D x: elementwise product with the diagonal.
        return self._times_helper(x, spaces, operation=lambda z: z.__mul__)

    def _adjoint_times(self, x, spaces):
        # D^dagger x: elementwise product with the conjugated diagonal.
        return self._times_helper(x, spaces,
                                  operation=lambda z: z.conjugate().__mul__)

    def _inverse_times(self, x, spaces):
        # D^-1 x: `z.__rtruediv__(x)` evaluates `x / z`.
        return self._times_helper(x, spaces,
                                  operation=lambda z: z.__rtruediv__)

    def _adjoint_inverse_times(self, x, spaces):
        # (D^dagger)^-1 x: `x / conj(z)`.
        return self._times_helper(x, spaces,
                                  operation=lambda z:
                                      z.conjugate().__rtruediv__)

    def diagonal(self, copy=True):
        """ Returns the diagonal of the Operator.

        Parameters
        ----------
        copy : boolean
            Whether the returned Field should be copied or not.

        Returns
        -------
        out : Field
            The diagonal of the Operator. With `copy=False` this is the
            internally stored Field itself.

        """
        return self._diagonal.copy() if copy else self._diagonal

    def inverse_diagonal(self):
        """ Returns the inverse-diagonal of the operator.

        Returns
        -------
        out : Field
            The elementwise reciprocal `1./diagonal` of the Operator.

        """
        return 1./self.diagonal(copy=False)

    # ---Mandatory properties and methods---

    @property
    def domain(self):
        return self._domain

    @property
    def self_adjoint(self):
        """ Whether the operator equals its own adjoint (cached).

        Returns
        -------
        out : bool
            True if the diagonal's dtype is not complex, or if all of its
            imaginary parts are exactly zero.

        """
        if self._self_adjoint is None:
            if not issubclass(self._diagonal.dtype.type, np.complexfloating):
                self._self_adjoint = True
            else:
                # Cast: ndarray.all() yields numpy.bool_, not a Python bool.
                self._self_adjoint = bool(
                    (self._diagonal.val.imag == 0).all())
        return self._self_adjoint

    @property
    def unitary(self):
        """ Whether every diagonal entry has modulus exactly 1 (cached).

        Returns
        -------
        out : bool

        """
        if self._unitary is None:
            # Cast: ndarray.all() yields numpy.bool_, not a Python bool.
            self._unitary = bool((abs(self._diagonal.val) == 1.).all())
        return self._unitary

    # ---Added properties and methods---

    def set_diagonal(self, diagonal, copy=True):
        """ Sets the diagonal of the Operator.

        This also invalidates the cached `self_adjoint` and `unitary`
        values.

        Parameters
        ----------
        diagonal : {scalar, list, array, Field}
            The diagonal entries of the operator.
        copy : boolean
            Specifies if a copy of the input shall be made (default: True).

        """

        # use the casting functionality from Field to process `diagonal`
        f = Field(domain=self.domain, val=diagonal, copy=copy)

        # Reset the cached self_adjoint and unitarity properties; they are
        # recomputed on next access.
        self._self_adjoint = None
        self._unitary = None

        # store the diagonal-field
        self._diagonal = f

    def _times_helper(self, x, spaces, operation):
        """ Combine the diagonal elementwise with `x` via `operation`.

        Parameters
        ----------
        x : Field
            The input field.
        spaces : sequence of ints or None
            Indices into `x.domain` selecting the sub-domains the operator
            acts on. None means all axes of `x`.
        operation : callable
            Receives the diagonal (the stored Field, or its values reshaped
            as a numpy array). It returns a callable that is then applied
            to `x` (or to `x.val`).

        Returns
        -------
        out
            The elementwise result. In the reshaping branch it is wrapped
            explicitly as `Field(x.domain, ...)`. In the direct branch it is
            whatever the Field operation returns.

        """
        # if the domain matches directly
        # -> combine the fields directly
        if x.domain == self.domain:
            return operation(self._diagonal)(x)

        ndim = len(x.shape)
        if spaces is None:
            active_axes = set(range(ndim))
        else:
            active_axes = set()
            for space_index in spaces:
                active_axes.update(x.domain.axes[space_index])

        # Broadcastable shape: keep the axis length on the axes the operator
        # acts on, use 1 on all other axes.
        reshaper = [x.shape[i] if i in active_axes else 1
                    for i in range(ndim)]
        reshaped_local_diagonal = np.reshape(self._diagonal.val, reshaper)

        # here the actual elementwise operation takes place
        return Field(x.domain, val=operation(reshaped_local_diagonal)(x.val))