# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from __future__ import division
import numpy as np
from ..field import Field
from ..domain_tuple import DomainTuple
from .endomorphic_operator import EndomorphicOperator
from ..nifty_utilities import cast_iseq_to_tuple
from .. import dobj


class DiagonalOperator(EndomorphicOperator):
    """ NIFTY class for diagonal operators.

    The NIFTY DiagonalOperator class is a subclass derived from the
    EndomorphicOperator. It multiplies an input field pixel-wise with its
    diagonal.

    Parameters
    ----------
    diagonal : Field
        The diagonal entries of the operator
        (already containing volume factors).
    domain : tuple of DomainObjects, i.e. Spaces and FieldTypes
        The domain on which the Operator's input Field lives.
        If None, use the domain of "diagonal".
    spaces : tuple of int
        The elements of "domain" on which the operator acts.
        If None, it acts on all elements.
        The i-th entry is the index in "domain" of the i-th sub-domain of
        "diagonal". The entries do not have to be in ascending order.

    Attributes
    ----------
    domain : tuple of DomainObjects, i.e. Spaces and FieldTypes
        The domain on which the Operator's input Field lives.
    target : tuple of DomainObjects, i.e. Spaces and FieldTypes
        The domain in which the outcome of the operator lives. As the Operator
        is endomorphic this is the same as its domain.
    unitary : boolean
        Indicates whether the Operator is unitary or not.
    self_adjoint : boolean
        Indicates whether the operator is self-adjoint or not.

    Raises
    ------
    TypeError
        If "diagonal" is not a Field.
    ValueError
        If "diagonal", "domain" and "spaces" are not consistent with each
        other (length mismatch, duplicate space indices, too many spaces,
        or mismatching sub-domains).

    NOTE: the fields given to __init__ and returned from .diagonal() are
    considered to be non-bare, i.e. during operator application, no additional
    volume factors are applied!

    See Also
    --------
    EndomorphicOperator
    """

    def __init__(self, diagonal, domain=None, spaces=None):
        super(DiagonalOperator, self).__init__()

        if not isinstance(diagonal, Field):
            raise TypeError("Field object required")

        if domain is None:
            self._domain = diagonal.domain
        else:
            self._domain = DomainTuple.make(domain)

        if spaces is None:
            self._spaces = None
            if diagonal.domain != self._domain:
                raise ValueError("domain mismatch")
        else:
            self._spaces = cast_iseq_to_tuple(spaces)
            nspc = len(self._spaces)
            if nspc != len(diagonal.domain.domains):
                raise ValueError("spaces and domain must have the same length")
            if nspc > len(self._domain.domains):
                raise ValueError("too many spaces")
            if nspc > len(set(self._spaces)):
                raise ValueError("non-unique space indices")
            # TODO(review): if nspc == len(self._domain.domains) (all spaces
            # active, possibly permuted), some optimization would be possible.
            for i, j in enumerate(self._spaces):
                if diagonal.domain[i] != self._domain[j]:
                    raise ValueError("domain mismatch")
            # Diagonal acts on all spaces in their natural order: plain
            # Field multiplication suffices.
            if self._spaces == tuple(range(len(self._domain.domains))):
                self._spaces = None  # shortcut

        if self._spaces is not None:
            active_axes = []
            for space_index in self._spaces:
                active_axes += self._domain.axes[space_index]

            # Broadcast shape for the diagonal: full extent along the axes of
            # the active spaces, 1 along every other axis of "domain".
            self._reshaper = [shp if i in active_axes else 1
                              for i, shp in enumerate(self._domain.shape)]

            # np.reshape never reorders axes. The diagonal's axes are laid out
            # in "spaces" order, but _reshaper is in "domain" order. Permute
            # the diagonal's axes so that its sub-domains appear in ascending
            # order of their space index. Without this, unsorted "spaces"
            # (e.g. (2, 0)) would silently scramble the diagonal. For sorted
            # "spaces" this is the identity permutation.
            order = sorted(range(len(self._spaces)),
                           key=lambda k: self._spaces[k])
            self._diag_axes = [ax for k in order
                               for ax in diagonal.domain.axes[k]]

            # Whether array axis 0 of "domain" belongs to an active space;
            # decides in _times_helper whether the diagonal must be cut down
            # to its locally held portion.
            self._first_axis_active = 0 in active_axes

        self._diagonal = diagonal.copy()

        # Lazily computed caches for the self_adjoint / unitary properties.
        self._self_adjoint = None
        self._unitary = None

    def _times(self, x):
        """Apply the operator: multiply x pixel-wise by the diagonal."""
        return self._times_helper(x, self._diagonal)

    def _adjoint_times(self, x):
        """Apply the adjoint: multiply x by the complex-conjugated diagonal."""
        return self._times_helper(x, self._diagonal.conj())

    def _inverse_times(self, x):
        """Apply the inverse: multiply x by the reciprocal diagonal."""
        return self._times_helper(x, 1./self._diagonal)

    def _adjoint_inverse_times(self, x):
        """Apply the adjoint inverse: multiply x by 1/conj(diagonal)."""
        return self._times_helper(x, 1./self._diagonal.conj())

    def diagonal(self):
        """ Returns the diagonal of the Operator.

        Returns
        -------
        out : Field
            The diagonal of the Operator.
        """
        return self._diagonal.copy()

    @property
    def domain(self):
        return self._domain

    @property
    def self_adjoint(self):
        """Whether the operator equals its adjoint (cached after first use).

        Always True for non-complex dtypes. For complex dtypes it is True only
        if every diagonal entry has a zero imaginary part.
        """
        if self._self_adjoint is None:
            if not issubclass(self._diagonal.dtype.type, np.complexfloating):
                self._self_adjoint = True
            else:
                self._self_adjoint = (self._diagonal.val.imag == 0).all()
        return self._self_adjoint

    @property
    def unitary(self):
        """Whether every diagonal entry has absolute value exactly 1 (cached)."""
        if self._unitary is None:
            self._unitary = (abs(self._diagonal.val) == 1.).all()
        return self._unitary

    def _times_helper(self, x, diag):
        """Multiply the Field x pixel-wise by diag (a Field on the diagonal's
        domain) and return the result as a new Field on x.domain.

        If the operator acts only on some spaces of the domain, diag is
        broadcast along all inactive axes.
        """
        if self._spaces is None:
            return diag*x

        # Bring the diagonal's axes into domain order before reshaping it to
        # the broadcastable shape.
        ldiag = np.transpose(dobj.to_global_data(diag.val), self._diag_axes)
        ldiag = np.reshape(ldiag, self._reshaper)
        if self._first_axis_active:
            # NOTE(review): x.val presumably is distributed along array axis 0.
            # In that case, only the locally held slice of the diagonal can
            # broadcast against the local part of x.
            ldiag = dobj.local_data(dobj.from_global_data(ldiag))
        return Field(x.domain, val=x.val*ldiag)