# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from __future__ import (absolute_import, division, print_function)
from builtins import *
import numpy as np
from ..field import Field
from ..domain_tuple import DomainTuple
from .endomorphic_operator import EndomorphicOperator
from .. import utilities
from .. import dobj


class DiagonalOperator(EndomorphicOperator):
    """ NIFTy class for diagonal operators.

    The NIFTy DiagonalOperator class is a subclass derived from the
    EndomorphicOperator. It multiplies an input field pixel-wise with its
    diagonal.

    Parameters
    ----------
    diagonal : Field
        The diagonal entries of the operator.
    domain : Domain, tuple of Domain or DomainTuple, optional
        The domain on which the Operator's input Field lives.
        If None, use the domain of "diagonal".
    spaces : int or tuple of int, optional
        The elements of "domain" on which the operator acts.
        If None, it acts on all elements. The indices must be given in
        ascending order; the i-th index names the element of "domain" that
        corresponds to the i-th element of the domain of "diagonal".

    Raises
    ------
    TypeError
        If "diagonal" is not a Field.
    ValueError
        If the domain of "diagonal" is incompatible with "domain" and
        "spaces", or if "spaces" is not in ascending order.

    Notes
    -----
    Formally, this operator always supports all operation modes (times,
    adjoint_times, inverse_times and inverse_adjoint_times), even if there
    are diagonal elements with value 0 or infinity. It is the user's
    responsibility to apply the operator only in appropriate ways (e.g. call
    inverse_times only if there are no zeros on the diagonal).

    This shortcoming will hopefully be fixed in the future.
    """

    def __init__(self, diagonal, domain=None, spaces=None):
        super(DiagonalOperator, self).__init__()

        if not isinstance(diagonal, Field):
            raise TypeError("Field object required")
        if domain is None:
            self._domain = diagonal.domain
        else:
            self._domain = DomainTuple.make(domain)
        if spaces is None:
            self._spaces = None
            if diagonal.domain != self._domain:
                raise ValueError("domain mismatch")
        else:
            self._spaces = utilities.parse_spaces(spaces, len(self._domain))
            if len(self._spaces) != len(diagonal.domain):
                raise ValueError("spaces and domain must have the same length")
            # The axes of "diagonal" are laid out in the order given by
            # "spaces", but the reshape below inserts them in the axis order
            # of self._domain (reshape never transposes). Unsorted indices
            # would therefore silently scramble the diagonal entries.
            if list(self._spaces) != sorted(self._spaces):
                raise ValueError("spaces must be given in ascending order")
            for i, j in enumerate(self._spaces):
                if diagonal.domain[i] != self._domain[j]:
                    raise ValueError("domain mismatch")
            if self._spaces == tuple(range(len(self._domain))):
                self._spaces = None  # shortcut

        if self._spaces is not None:
            # Array axes of self._domain that the diagonal covers; every other
            # axis gets length 1 so the diagonal broadcasts against inputs.
            active_axes = []
            for space_index in self._spaces:
                active_axes += self._domain.axes[space_index]

            # When the diagonal starts at space 0, its leading axis is the
            # leading axis of self._domain, so its local data already matches
            # the local layout (NOTE(review): presumably data is distributed
            # along axis 0, cf. local_shape(..., 0) below). Otherwise, or for
            # an empty "spaces" (scalar diagonal), the global data is used.
            if self._spaces and self._spaces[0] == 0:
                self._ldiag = diagonal.local_data
            else:
                self._ldiag = diagonal.to_global_data()
            locshape = dobj.local_shape(self._domain.shape, 0)
            self._reshaper = [shp if i in active_axes else 1
                              for i, shp in enumerate(locshape)]
            self._ldiag = self._ldiag.reshape(self._reshaper)
        else:
            self._ldiag = diagonal.local_data

        self._update_diagmin()

    def _update_diagmin(self):
        """Freeze the stored diagonal and cache its global minimum.

        The minimum is only computed for non-complex diagonals; it is used
        by `draw_sample` to check (semi-)definiteness.
        """
        self._ldiag.flags.writeable = False
        if not np.issubdtype(self._ldiag.dtype, np.complexfloating):
            # A task without local entries contributes 1., which cannot turn
            # the sign checks in draw_sample (< 0, == 0) true.
            lmin = self._ldiag.min() if self._ldiag.size > 0 else 1.
            self._diagmin = dobj.np_allreduce_min(np.array(lmin))[()]

    def _from_ldiag(self, spc, ldiag):
        """Create a DiagonalOperator on the same domain from local data.

        Bypasses `__init__`; `ldiag` must already have the local, broadcast-
        ready layout. `spc` are the spaces touched by the other operand
        (None meaning "all spaces").
        """
        res = DiagonalOperator.__new__(DiagonalOperator)
        res._domain = self._domain
        if self._spaces is None or spc is None:
            res._spaces = None
        else:
            # Keep the ascending-order invariant enforced in __init__ and
            # apply the same "all spaces -> None" shortcut.
            spaces = tuple(sorted(set(self._spaces) | set(spc)))
            if spaces == tuple(range(len(self._domain))):
                spaces = None
            res._spaces = spaces
        res._ldiag = ldiag
        res._update_diagmin()
        return res

    def _scale(self, fct):
        if not np.isscalar(fct):
            raise TypeError("scalar value required")
        return self._from_ldiag((), self._ldiag*fct)

    def _add(self, sum):
        if not np.isscalar(sum):
            raise TypeError("scalar value required")
        return self._from_ldiag((), self._ldiag+sum)

    def _combine_prod(self, op):
        if not isinstance(op, DiagonalOperator):
            raise TypeError("DiagonalOperator required")
        # Without this check, diagonals of different domains would be
        # combined by NumPy broadcasting (or fail with an obscure error).
        if op._domain != self._domain:
            raise ValueError("domain mismatch")
        return self._from_ldiag(op._spaces, self._ldiag*op._ldiag)

    def _combine_sum(self, op, selfneg, opneg):
        if not isinstance(op, DiagonalOperator):
            raise TypeError("DiagonalOperator required")
        if op._domain != self._domain:
            raise ValueError("domain mismatch")
        tdiag = (self._ldiag * (-1 if selfneg else 1) +
                 op._ldiag * (-1 if opneg else 1))
        return self._from_ldiag(op._spaces, tdiag)

    def apply(self, x, mode):
        self._check_input(x, mode)

        # Real floating-point diagonals are their own conjugate, so the
        # conj() copy is skipped for them in the adjoint modes.
        if mode == self.TIMES:
            return Field(x.domain, val=x.val*self._ldiag)
        elif mode == self.ADJOINT_TIMES:
            if np.issubdtype(self._ldiag.dtype, np.floating):
                return Field(x.domain, val=x.val*self._ldiag)
            else:
                return Field(x.domain, val=x.val*self._ldiag.conj())
        elif mode == self.INVERSE_TIMES:
            return Field(x.domain, val=x.val/self._ldiag)
        else:
            if np.issubdtype(self._ldiag.dtype, np.floating):
                return Field(x.domain, val=x.val/self._ldiag)
            else:
                return Field(x.domain, val=x.val/self._ldiag.conj())

    @property
    def domain(self):
        return self._domain

    @property
    def capability(self):
        return self._all_ops

    def _flip_modes(self, trafo):
        """Return the adjoint and/or inverse of this operator.

        `trafo` is a combination of ADJOINT_BIT and INVERSE_BIT.
        """
        ADJ = self.ADJOINT_BIT
        INV = self.INVERSE_BIT

        if trafo == 0:
            return self
        if trafo == ADJ and np.issubdtype(self._ldiag.dtype, np.floating):
            return self
        if trafo == ADJ:
            return self._from_ldiag((), self._ldiag.conjugate())
        elif trafo == INV:
            return self._from_ldiag((), 1./self._ldiag)
        elif trafo == ADJ | INV:
            return self._from_ldiag((), 1./self._ldiag.conjugate())
        raise ValueError("invalid operator transformation")

    def draw_sample(self, from_inverse=False, dtype=np.float64):
        """Draw a Gaussian sample with covariance given by this operator.

        Raises ValueError if the diagonal is complex or has negative entries,
        or has zero entries when sampling from the inverse.
        """
        if np.issubdtype(self._ldiag.dtype, np.complexfloating):
            raise ValueError("operator not positive definite")
        if self._diagmin < 0.:
            raise ValueError("operator not positive definite")
        if self._diagmin == 0. and from_inverse:
            raise ValueError("operator not positive definite")
        res = Field.from_random(random_type="normal", domain=self._domain,
                                dtype=dtype)
        if from_inverse:
            return res/np.sqrt(self._ldiag)
        else:
            return res*np.sqrt(self._ldiag)