# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

19
from __future__ import absolute_import, division, print_function
20

21
import numpy as np
22
23
24

from .. import dobj, utilities
from ..compat import *
Martin Reinecke's avatar
Martin Reinecke committed
25
from ..domain_tuple import DomainTuple
26
from ..field import Field
Martin Reinecke's avatar
Martin Reinecke committed
27
from .endomorphic_operator import EndomorphicOperator
Martin Reinecke's avatar
Martin Reinecke committed
28

29
30

class DiagonalOperator(EndomorphicOperator):
    """ NIFTy class for diagonal operators.

    The NIFTy DiagonalOperator class is a subclass derived from the
    EndomorphicOperator. It multiplies an input field pixel-wise with its
    diagonal.

    Parameters
    ----------
    diagonal : Field
        The diagonal entries of the operator.
    domain : Domain, tuple of Domain or DomainTuple, optional
        The domain on which the Operator's input Field lives.
        If None, use the domain of "diagonal".
    spaces : int or tuple of int, optional
        The elements of "domain" on which the operator acts.
        If None, it acts on all elements. The i-th entry of "spaces" must
        be the index (in "domain") of the i-th sub-domain of
        "diagonal".domain. The entries need not be in ascending order.

    Notes
    -----
    Formally, this operator always supports all operation modes (times,
    adjoint_times, inverse_times and inverse_adjoint_times), even if there
    are diagonal elements with value 0 or infinity. It is the user's
    responsibility to apply the operator only in appropriate ways (e.g. call
    inverse_times only if there are no zeros on the diagonal).

    This shortcoming will hopefully be fixed in the future.
    """

    def __init__(self, diagonal, domain=None, spaces=None):
        if not isinstance(diagonal, Field):
            raise TypeError("Field object required")
        if domain is None:
            self._domain = diagonal.domain
        else:
            self._domain = DomainTuple.make(domain)
        if spaces is None:
            self._spaces = None
            if diagonal.domain != self._domain:
                raise ValueError("domain mismatch")
        else:
            self._spaces = utilities.parse_spaces(spaces, len(self._domain))
            if len(self._spaces) != len(diagonal.domain):
                raise ValueError("spaces and domain must have the same length")
            for i, j in enumerate(self._spaces):
                if diagonal.domain[i] != self._domain[j]:
                    raise ValueError("domain mismatch")
            if self._spaces == tuple(range(len(self._domain))):
                self._spaces = None  # shortcut

        if self._spaces is not None:
            # The local part of "diagonal" is only used when the diagonal's
            # first sub-domain is self._domain[0], so both arrays share their
            # first axis; otherwise the full (global) array is fetched.
            if self._spaces[0] == 0:
                ldiag = diagonal.local_data
            else:
                ldiag = diagonal.to_global_data()

            # The reshape below lays the diagonal's axes out in the axis
            # order of self._domain. A plain reshape does not permute axes,
            # so for "spaces" given out of ascending order the axes must be
            # transposed first, otherwise the entries would be scrambled.
            # If spaces[0] == 0, index 0 is also first after sorting, so the
            # leading axis of the local data stays in front.
            order = sorted(range(len(self._spaces)),
                           key=self._spaces.__getitem__)
            if order != list(range(len(order))):
                perm = [ax for k in order for ax in diagonal.domain.axes[k]]
                ldiag = np.transpose(ldiag, perm)
                self._spaces = tuple(self._spaces[k] for k in order)

            active_axes = []
            for space_index in self._spaces:
                active_axes += self._domain.axes[space_index]
            locshape = dobj.local_shape(self._domain.shape, 0)
            # Inactive axes get extent 1 so the diagonal broadcasts over them.
            self._reshaper = [shp if i in active_axes else 1
                              for i, shp in enumerate(locshape)]
            self._ldiag = ldiag.reshape(self._reshaper)
        else:
            self._ldiag = diagonal.local_data
        self._fill_rest()

    def _fill_rest(self):
        """Finish initialisation from an already-set ``self._ldiag``.

        Locks the diagonal array against writes and records whether it is
        complex. For real diagonals, it also stores the minimum entry as
        ``self._diagmin``, reduced via ``dobj.np_allreduce_min``. An empty
        local array contributes 1.
        """
        self._ldiag.flags.writeable = False
        self._complex = utilities.iscomplextype(self._ldiag.dtype)
        self._capability = self._all_ops
        if not self._complex:
            lmin = self._ldiag.min() if self._ldiag.size > 0 else 1.
            self._diagmin = dobj.np_allreduce_min(np.array(lmin))[()]

    def _from_ldiag(self, spc, ldiag):
        """Build a new DiagonalOperator on the same domain from raw data.

        Parameters
        ----------
        spc : tuple of int or None
            Spaces touched by the other operand. None means "all spaces".
        ldiag : numpy.ndarray
            Diagonal data, already shaped for broadcasting against the
            local data of fields on ``self._domain``.
        """
        res = DiagonalOperator.__new__(DiagonalOperator)
        res._domain = self._domain
        if self._spaces is None or spc is None:
            res._spaces = None
        else:
            # sorted() gives a deterministic order, consistent with __init__
            res._spaces = tuple(sorted(set(self._spaces) | set(spc)))
        res._ldiag = ldiag
        res._fill_rest()
        return res

    def _scale(self, fct):
        """Return a copy of this operator with the diagonal multiplied by the
        scalar ``fct``. Raises TypeError for non-scalars."""
        if not np.isscalar(fct):
            raise TypeError("scalar value required")
        return self._from_ldiag((), self._ldiag*fct)

    def _add(self, sum):
        """Return a copy of this operator with the scalar ``sum`` added to
        every diagonal entry. Raises TypeError for non-scalars."""
        if not np.isscalar(sum):
            raise TypeError("scalar value required")
        return self._from_ldiag((), self._ldiag+sum)

    def _combine_prod(self, op):
        """Return the product of this and another DiagonalOperator as a
        single DiagonalOperator. Raises TypeError for other operator types."""
        if not isinstance(op, DiagonalOperator):
            raise TypeError("DiagonalOperator required")
        return self._from_ldiag(op._spaces, self._ldiag*op._ldiag)

    def _combine_sum(self, op, selfneg, opneg):
        """Return (+/-)self + (+/-)op as a single DiagonalOperator.

        The flags ``selfneg`` and ``opneg`` negate the respective operand.
        Raises TypeError if ``op`` is not a DiagonalOperator.
        """
        if not isinstance(op, DiagonalOperator):
            raise TypeError("DiagonalOperator required")
        tdiag = (self._ldiag * (-1 if selfneg else 1) +
                 op._ldiag * (-1 if opneg else 1))
        return self._from_ldiag(op._spaces, tdiag)

    def apply(self, x, mode):
        """Apply the operator to the Field ``x`` in the given ``mode``.

        Modes with a bit from 3 (times / adjoint_times) multiply by the
        diagonal; the others (inverse modes) divide by it. The adjoint modes
        (bits in 10) use the complex conjugate of the diagonal.
        """
        self._check_input(x, mode)
        # shortcut for most common cases
        if mode == 1 or (not self._complex and mode == 2):
            return Field.from_local_data(x.domain, x.local_data*self._ldiag)

        xdiag = self._ldiag
        if self._complex and (mode & 10):  # adjoint or inverse adjoint
            xdiag = xdiag.conj()

        if mode & 3:
            return Field.from_local_data(x.domain, x.local_data*xdiag)
        return Field.from_local_data(x.domain, x.local_data/xdiag)

    def _flip_modes(self, trafo):
        """Return the adjoint and/or inverse of this operator as a new
        DiagonalOperator (conjugated and/or reciprocal diagonal).

        The adjoint of a real diagonal operator is the operator itself.
        """
        if trafo == self.ADJOINT_BIT and not self._complex:  # shortcut
            return self
        xdiag = self._ldiag
        if self._complex and (trafo & self.ADJOINT_BIT):
            xdiag = xdiag.conj()
        if trafo & self.INVERSE_BIT:
            xdiag = 1./xdiag
        return self._from_ldiag((), xdiag)

    def process_sample(self, samp, from_inverse):
        """Scale the Field ``samp`` by sqrt(diagonal), or by
        1/sqrt(diagonal) if ``from_inverse`` is true.

        Raises ValueError if the diagonal is complex or has a negative
        entry, or has a zero entry while ``from_inverse`` is requested.
        """
        if (self._complex or (self._diagmin < 0.) or
                (self._diagmin == 0. and from_inverse)):
            raise ValueError("operator not positive definite")
        if from_inverse:
            res = samp.local_data/np.sqrt(self._ldiag)
        else:
            res = samp.local_data*np.sqrt(self._ldiag)
        return Field.from_local_data(self._domain, res)

    def draw_sample(self, from_inverse=False, dtype=np.float64):
        """Draw a "normal" random Field on the operator's domain and scale
        it with ``process_sample``."""
        res = Field.from_random(random_type="normal", domain=self._domain,
                                dtype=dtype)
        return self.process_sample(res, from_inverse)

    def __repr__(self):
        return "DiagonalOperator"