diagonal_operator.py 7.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
18

19
from __future__ import absolute_import, division, print_function
20

21
import numpy as np
22 23 24

from .. import dobj, utilities
from ..compat import *
Martin Reinecke's avatar
Martin Reinecke committed
25
from ..domain_tuple import DomainTuple
26
from ..field import Field
Martin Reinecke's avatar
Martin Reinecke committed
27
from .endomorphic_operator import EndomorphicOperator
Martin Reinecke's avatar
Martin Reinecke committed
28

29 30

class DiagonalOperator(EndomorphicOperator):
    """ NIFTy class for diagonal operators.

    The NIFTy DiagonalOperator class is a subclass derived from the
    EndomorphicOperator. It multiplies an input field pixel-wise with its
    diagonal.

    Parameters
    ----------
    diagonal : Field
        The diagonal entries of the operator.
    domain : Domain, tuple of Domain or DomainTuple, optional
        The domain on which the Operator's input Field lives.
        If None, use the domain of "diagonal".
    spaces : int or tuple of int, optional
        The elements of "domain" on which the operator acts.
        If None, it acts on all elements.

    Raises
    ------
    TypeError
        If `diagonal` is not a Field.
    ValueError
        If the domain of `diagonal` does not match `domain` (or, when
        `spaces` is given, the selected elements of `domain`).

    Notes
    -----
    Formally, this operator always supports all operation modes (times,
    adjoint_times, inverse_times and inverse_adjoint_times), even if there
    are diagonal elements with value 0 or infinity. It is the user's
    responsibility to apply the operator only in appropriate ways (e.g. call
    inverse_times only if there are no zeros on the diagonal).

    This shortcoming will hopefully be fixed in the future.
    """

    def __init__(self, diagonal, domain=None, spaces=None):
        super(DiagonalOperator, self).__init__()

        if not isinstance(diagonal, Field):
            raise TypeError("Field object required")
        if domain is None:
            self._domain = diagonal.domain
        else:
            self._domain = DomainTuple.make(domain)

        if spaces is None:
            self._spaces = None
            # Identity comparison: the diagonal must live on exactly the
            # operator's DomainTuple.
            if diagonal.domain is not self._domain:
                raise ValueError("domain mismatch")
        else:
            self._spaces = utilities.parse_spaces(spaces, len(self._domain))
            if len(self._spaces) != len(diagonal.domain):
                raise ValueError("spaces and domain must have the same length")
            # The i-th sub-domain of the diagonal must equal the j-th
            # sub-domain of the operator's domain.
            for i, j in enumerate(self._spaces):
                if diagonal.domain[i] != self._domain[j]:
                    raise ValueError("domain mismatch")
            if self._spaces == tuple(range(len(self._domain))):
                self._spaces = None  # shortcut: acts on the whole domain

        if self._spaces is not None:
            # Array axes of the full domain that belong to the active spaces.
            active_axes = []
            for space_index in self._spaces:
                active_axes += self._domain.axes[space_index]

            # The diagonal's local data is used as-is only when space 0 is
            # active; otherwise the full (global) diagonal is fetched.
            # NOTE(review): presumably because dobj distributes arrays along
            # axis 0 of the full domain -- confirm against dobj.
            if self._spaces[0] == 0:
                self._ldiag = diagonal.local_data
            else:
                self._ldiag = diagonal.to_global_data()
            locshape = dobj.local_shape(self._domain.shape, 0)
            # Insert singleton dimensions for inactive axes so that _ldiag
            # broadcasts against the local array of an input field.
            self._reshaper = [shp if i in active_axes else 1
                              for i, shp in enumerate(locshape)]
            self._ldiag = self._ldiag.reshape(self._reshaper)
        else:
            self._ldiag = diagonal.local_data
        self._update_diagmin()

    def _update_diagmin(self):
        """Freeze `_ldiag` and cache its global minimum in `_diagmin`.

        The minimum is only computed for non-complex dtypes; for complex
        diagonals `_diagmin` is not set.
        """
        # Marks this array object read-only, so it cannot be modified in
        # place through _ldiag.
        self._ldiag.flags.writeable = False
        if not np.issubdtype(self._ldiag.dtype, np.complexfloating):
            # An empty local chunk contributes the neutral value 1.
            lmin = self._ldiag.min() if self._ldiag.size > 0 else 1.
            self._diagmin = dobj.np_allreduce_min(np.array(lmin))[()]

    def _from_ldiag(self, spc, ldiag):
        """Build a new DiagonalOperator on the same domain from local data.

        Parameters
        ----------
        spc : tuple of int or None
            Spaces acted on by the other operand (None means all spaces,
            () means no additional spaces).
        ldiag : numpy.ndarray
            Local diagonal data, already shaped to broadcast against the
            local input array.

        Returns
        -------
        DiagonalOperator
            The new operator. It acts on the union of this operator's spaces
            and `spc`, or on all spaces if either of them is None.
        """
        # Bypass __init__: ldiag is already in its final local layout.
        res = DiagonalOperator.__new__(DiagonalOperator)
        res._domain = self._domain
        if self._spaces is None or spc is None:
            res._spaces = None
        else:
            # Sorted so the result is deterministic; set order is arbitrary.
            res._spaces = tuple(sorted(set(self._spaces) | set(spc)))
        res._ldiag = ldiag
        res._update_diagmin()
        return res

    def _scale(self, fct):
        """Return a new operator whose diagonal is multiplied by scalar `fct`.

        Raises
        ------
        TypeError
            If `fct` is not a scalar.
        """
        if not np.isscalar(fct):
            raise TypeError("scalar value required")
        return self._from_ldiag((), self._ldiag*fct)

    def _add(self, sum):
        """Return a new operator whose diagonal has scalar `sum` added.

        Raises
        ------
        TypeError
            If `sum` is not a scalar.
        """
        if not np.isscalar(sum):
            raise TypeError("scalar value required")
        return self._from_ldiag((), self._ldiag+sum)

    def _combine_prod(self, op):
        """Return the product of this operator and DiagonalOperator `op`.

        Raises
        ------
        TypeError
            If `op` is not a DiagonalOperator.
        """
        if not isinstance(op, DiagonalOperator):
            raise TypeError("DiagonalOperator required")
        # The local arrays broadcast against each other if the operators
        # act on different spaces.
        return self._from_ldiag(op._spaces, self._ldiag*op._ldiag)

    def _combine_sum(self, op, selfneg, opneg):
        """Return (+/-)self + (+/-)op for a DiagonalOperator `op`.

        `selfneg` / `opneg` select whether the respective diagonal is negated.

        Raises
        ------
        TypeError
            If `op` is not a DiagonalOperator.
        """
        if not isinstance(op, DiagonalOperator):
            raise TypeError("DiagonalOperator required")
        tdiag = (self._ldiag * (-1 if selfneg else 1) +
                 op._ldiag * (-1 if opneg else 1))
        return self._from_ldiag(op._spaces, tdiag)

    def apply(self, x, mode):
        """Apply the operator (or its adjoint/inverse) pixel-wise to `x`."""
        self._check_input(x, mode)

        if mode == self.TIMES:
            return Field(x.domain, val=x.val*self._ldiag)
        elif mode == self.ADJOINT_TIMES:
            # For real floating diagonals the conjugate is a no-op; skip it.
            if np.issubdtype(self._ldiag.dtype, np.floating):
                return Field(x.domain, val=x.val*self._ldiag)
            else:
                return Field(x.domain, val=x.val*self._ldiag.conj())
        elif mode == self.INVERSE_TIMES:
            return Field(x.domain, val=x.val/self._ldiag)
        else:
            if np.issubdtype(self._ldiag.dtype, np.floating):
                return Field(x.domain, val=x.val/self._ldiag)
            else:
                return Field(x.domain, val=x.val/self._ldiag.conj())

    @property
    def domain(self):
        return self._domain

    @property
    def capability(self):
        return self._all_ops

    def _flip_modes(self, trafo):
        """Return the adjoint/inverse/adjoint-inverse as a DiagonalOperator.

        Raises
        ------
        ValueError
            If `trafo` is not a valid combination of ADJOINT_BIT and
            INVERSE_BIT.
        """
        ADJ = self.ADJOINT_BIT
        INV = self.INVERSE_BIT

        if trafo == 0:
            return self
        # A real floating diagonal operator is self-adjoint.
        if trafo == ADJ and np.issubdtype(self._ldiag.dtype, np.floating):
            return self
        if trafo == ADJ:
            return self._from_ldiag((), self._ldiag.conjugate())
        elif trafo == INV:
            return self._from_ldiag((), 1./self._ldiag)
        elif trafo == ADJ | INV:
            return self._from_ldiag((), 1./self._ldiag.conjugate())
        raise ValueError("invalid operator transformation")

    def draw_sample(self, from_inverse=False, dtype=np.float64):
        """Draw a random Field from a normal distribution scaled by the diagonal.

        A standard-normal Field is multiplied by sqrt(diagonal), or divided by
        it if `from_inverse` is True.

        Raises
        ------
        ValueError
            If the diagonal is complex, has negative entries, or has zero
            entries while `from_inverse` is True.
        """
        if np.issubdtype(self._ldiag.dtype, np.complexfloating):
            raise ValueError("operator not positive definite")
        if self._diagmin < 0.:
            raise ValueError("operator not positive definite")
        if self._diagmin == 0. and from_inverse:
            raise ValueError("operator not positive definite")
        res = Field.from_random(random_type="normal", domain=self._domain,
                                dtype=dtype)
        if from_inverse:
            return res/np.sqrt(self._ldiag)
        else:
            return res*np.sqrt(self._ldiag)