# diagonal_operator.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from .. import dobj, utilities
from ..domain_tuple import DomainTuple
from ..field import Field
from .endomorphic_operator import EndomorphicOperator


class DiagonalOperator(EndomorphicOperator):
    """Represents a :class:`LinearOperator` which is diagonal.

    The NIFTy DiagonalOperator class is a subclass derived from the
    :class:`EndomorphicOperator`. It multiplies an input field pixel-wise with
    its diagonal.

    Parameters
    ----------
    diagonal : Field
        The diagonal entries of the operator.
    domain : Domain, tuple of Domain or DomainTuple, optional
        The domain on which the Operator's input Field is defined.
        If None, use the domain of "diagonal".
    spaces : int or tuple of int, optional
        The elements of "domain" on which the operator acts.
        If None, it acts on all elements. The indices must be unique and
        given in ascending order, and ``diagonal.domain[i]`` has to equal
        ``domain[spaces[i]]``. Along all other elements of "domain" the
        diagonal is broadcast (i.e. constant).

    Raises
    ------
    TypeError
        If `diagonal` is not a :class:`Field`.
    ValueError
        If the domain of `diagonal` is incompatible with `domain`/`spaces`,
        or if `spaces` is not strictly ascending.

    Notes
    -----
    Formally, this operator always supports all operation modes (times,
    adjoint_times, inverse_times and inverse_adjoint_times), even if there
    are diagonal elements with value 0 or infinity. It is the user's
    responsibility to apply the operator only in appropriate ways (e.g. call
    inverse_times only if there are no zeros on the diagonal).

    This shortcoming will hopefully be fixed in the future.
    """

    def __init__(self, diagonal, domain=None, spaces=None):
        if not isinstance(diagonal, Field):
            raise TypeError("Field object required")
        if domain is None:
            self._domain = diagonal.domain
        else:
            self._domain = DomainTuple.make(domain)
        if spaces is None:
            self._spaces = None
            if diagonal.domain != self._domain:
                raise ValueError("domain mismatch")
        else:
            self._spaces = utilities.parse_spaces(spaces, len(self._domain))
            if len(self._spaces) != len(diagonal.domain):
                raise ValueError("spaces and domain must have the same length")
            for i, j in enumerate(self._spaces):
                if diagonal.domain[i] != self._domain[j]:
                    raise ValueError("domain mismatch")
            # The reshape below lays the axes of `diagonal` onto the axes of
            # `domain` in their original order; it does not permute axes.
            # Unsorted (or repeated) space indices would therefore silently
            # scramble the diagonal entries, so reject them explicitly.
            if any(a >= b for a, b in zip(self._spaces, self._spaces[1:])):
                raise ValueError(
                    "spaces must be sorted in strictly ascending order")
            if self._spaces == tuple(range(len(self._domain))):
                self._spaces = None  # shortcut

        if self._spaces is not None:
            # Axes of the full domain along which the diagonal varies; all
            # other axes get length 1 so that the diagonal broadcasts there.
            active_axes = {axis
                           for space_index in self._spaces
                           for axis in self._domain.axes[space_index]}

            # Local data is distributed along axis 0 (see the `0` passed to
            # dobj.local_shape below).  The diagonal's local chunk only lines
            # up with that split when the diagonal covers domain element 0;
            # otherwise the full (global) diagonal is needed on every task.
            # `0 in ...` (rather than indexing [0]) also handles spaces=(),
            # i.e. a scalar diagonal broadcast over the whole domain.
            if 0 in self._spaces:
                self._ldiag = diagonal.local_data
            else:
                self._ldiag = diagonal.to_global_data()
            locshape = dobj.local_shape(self._domain.shape, 0)
            self._reshaper = [shp if i in active_axes else 1
                              for i, shp in enumerate(locshape)]
            self._ldiag = self._ldiag.reshape(self._reshaper)
        else:
            self._ldiag = diagonal.local_data
        self._fill_rest()

    def _fill_rest(self):
        """Finish setting up state derived from ``self._ldiag``.

        Marks the diagonal array read-only, records whether it is complex,
        declares support for all operation modes and, for real diagonals,
        caches the minimum entry (reduced over all local chunks via
        ``dobj.np_allreduce_min``) for the definiteness checks in
        :meth:`process_sample`.
        """
        self._ldiag.flags.writeable = False
        self._complex = utilities.iscomplextype(self._ldiag.dtype)
        self._capability = self._all_ops
        if not self._complex:
            # An empty local chunk must not influence the reduced minimum;
            # 1. is a neutral stand-in for the positivity checks.
            lmin = self._ldiag.min() if self._ldiag.size > 0 else 1.
            self._diagmin = dobj.np_allreduce_min(np.array(lmin))[()]

    def _from_ldiag(self, spc, ldiag):
        """Build a new DiagonalOperator on the same domain from raw data.

        Parameters
        ----------
        spc : tuple of int or None
            Spaces on which the *other* contributing diagonal acts; ``None``
            means "all spaces", ``()`` means "no additional spaces".
        ldiag : array-like
            Local diagonal data, already shaped for broadcasting.
        """
        res = DiagonalOperator.__new__(DiagonalOperator)
        res._domain = self._domain
        if self._spaces is None or spc is None:
            res._spaces = None
        else:
            # Keep the ascending-order invariant established in __init__.
            res._spaces = tuple(sorted(set(self._spaces) | set(spc)))
        res._ldiag = np.array(ldiag)
        res._fill_rest()
        return res

    def _scale(self, fct):
        """Return a DiagonalOperator whose diagonal is multiplied by `fct`.

        Raises TypeError if `fct` is not a scalar.
        """
        if not np.isscalar(fct):
            raise TypeError("scalar value required")
        return self._from_ldiag((), self._ldiag*fct)

    def _add(self, sum):
        """Return a DiagonalOperator whose diagonal is shifted by `sum`.

        The parameter name shadows the builtin ``sum``; it is kept for
        keyword compatibility with existing callers.
        Raises TypeError if `sum` is not a scalar.
        """
        if not np.isscalar(sum):
            raise TypeError("scalar value required")
        return self._from_ldiag((), self._ldiag+sum)

    def _combine_prod(self, op):
        """Return the product of this operator and the DiagonalOperator `op`.

        The diagonals are multiplied element-wise (with broadcasting).
        Raises TypeError if `op` is not a DiagonalOperator.
        """
        if not isinstance(op, DiagonalOperator):
            raise TypeError("DiagonalOperator required")
        return self._from_ldiag(op._spaces, self._ldiag*op._ldiag)

    def _combine_sum(self, op, selfneg, opneg):
        """Return (+/-)self + (+/-)op for the DiagonalOperator `op`.

        `selfneg` / `opneg` select whether the respective diagonal enters the
        sum with a negative sign. Raises TypeError if `op` is not a
        DiagonalOperator.
        """
        if not isinstance(op, DiagonalOperator):
            raise TypeError("DiagonalOperator required")
        tdiag = (self._ldiag * (-1 if selfneg else 1) +
                 op._ldiag * (-1 if opneg else 1))
        return self._from_ldiag(op._spaces, tdiag)

    def apply(self, x, mode):
        """Apply the operator to the Field `x` in the given `mode`.

        Forward and adjoint modes multiply by the (conjugated, for adjoint
        modes on complex diagonals) diagonal; inverse modes divide by it.
        No check for zero entries is made (see the class Notes).
        """
        self._check_input(x, mode)
        # shortcut for most common cases: plain times, or adjoint times of a
        # real diagonal (which is identical to times)
        if mode == 1 or (not self._complex and mode == 2):
            return Field.from_local_data(x.domain, x.local_data*self._ldiag)

        xdiag = self._ldiag
        if self._complex and (mode & 10):  # adjoint or inverse adjoint
            xdiag = xdiag.conj()

        if mode & 3:  # times or adjoint times: multiply
            return Field.from_local_data(x.domain, x.local_data*xdiag)
        # inverse or inverse adjoint: divide
        return Field.from_local_data(x.domain, x.local_data/xdiag)

    def _flip_modes(self, trafo):
        """Return the operator obtained by adjoining and/or inverting self.

        `trafo` is a bit mask of ``ADJOINT_BIT`` and ``INVERSE_BIT``. For a
        real diagonal the adjoint is the operator itself.
        """
        if trafo == self.ADJOINT_BIT and not self._complex:  # shortcut
            return self
        xdiag = self._ldiag
        if self._complex and (trafo & self.ADJOINT_BIT):
            xdiag = xdiag.conj()
        if trafo & self.INVERSE_BIT:
            xdiag = 1./xdiag
        return self._from_ldiag((), xdiag)

    def process_sample(self, samp, from_inverse):
        """Scale the Field `samp` by sqrt(diagonal) (or 1/sqrt if inverse).

        Raises
        ------
        ValueError
            If the diagonal is complex, has negative entries, or has zero
            entries while `from_inverse` is True.
        """
        if (self._complex or (self._diagmin < 0.) or
                (self._diagmin == 0. and from_inverse)):
            raise ValueError("operator not positive definite")
        if from_inverse:
            res = samp.local_data/np.sqrt(self._ldiag)
        else:
            res = samp.local_data*np.sqrt(self._ldiag)
        return Field.from_local_data(self._domain, res)

    def draw_sample(self, from_inverse=False, dtype=np.float64):
        """Draw a random Field with covariance given by this operator.

        A Field of normal random numbers of type `dtype` is generated on the
        operator's domain and passed through :meth:`process_sample`.
        """
        res = Field.from_random(random_type="normal", domain=self._domain,
                                dtype=dtype)
        return self.process_sample(res, from_inverse)

    def __repr__(self):
        return "DiagonalOperator"