diagonal_operator.py 10.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
Theo Steininger's avatar
Theo Steininger committed
13
14
15
16
17
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
18

Martin Reinecke's avatar
Martin Reinecke committed
19
20
from __future__ import division
from builtins import range
21
22
23
24
25
import numpy as np

from d2o import distributed_data_object,\
                STRATEGIES as DISTRIBUTION_STRATEGIES

Martin Reinecke's avatar
Martin Reinecke committed
26
27
28
29
from ...basic_arithmetics import log as nifty_log
from ...config import nifty_configuration as gc
from ...field import Field
from ..endomorphic_operator import EndomorphicOperator
30
31
32


class DiagonalOperator(EndomorphicOperator):
    """ NIFTY class for diagonal operators.

    The NIFTY DiagonalOperator class is a subclass derived from the
    EndomorphicOperator. It multiplies an input field pixel-wise with its
    diagonal.

    Parameters
    ----------
    domain : tuple of DomainObjects, i.e. Spaces and FieldTypes
        The domain on which the Operator's input Field lives.
    diagonal : {scalar, list, array, Field, d2o-object}
        The diagonal entries of the operator.
    bare : boolean
        Indicates whether the input for the diagonal is bare or not
        (default: False).
    copy : boolean
        Internal copy of the diagonal (default: True)
    distribution_strategy : string
        Sets the distribution_strategy of the diagonal (default: None).
        If None and `diagonal` is a d2o-object or Field, its
        distribution_strategy is used as a fallback; otherwise the
        configured default distribution strategy is used.
    default_spaces : tuple of ints *optional*
        Defines on which space(s) of a given field the Operator acts by
        default (default: None)

    Attributes
    ----------
    domain : tuple of DomainObjects, i.e. Spaces and FieldTypes
        The domain on which the Operator's input Field lives.
    target : tuple of DomainObjects, i.e. Spaces and FieldTypes
        The domain in which the outcome of the operator lives. As the Operator
        is endomorphic this is the same as its domain.
    unitary : boolean
        Indicates whether the Operator is unitary or not.
    self_adjoint : boolean
        Indicates whether the operator is self_adjoint or not.
    distribution_strategy : string
        Defines the distribution_strategy of the distributed_data_object
        in which the diagonal entries are stored.

    Raises
    ------
    ValueError
        If an explicit `distribution_strategy` is given that is not one of
        the strategies listed in d2o's STRATEGIES['all'].

    Notes
    -----
    The ambiguity of bare or non-bare diagonal entries is based on the choice
    of a matrix representation of the operator in question. The naive choice
    of absorbing the volume weights into the matrix leads to a matrix-vector
    calculus with the non-bare entries which seems intuitive, though.
    The choice of keeping matrix entries and volume weights separate
    deals with the bare entries that allow for correct interpretation
    of the matrix entries; e.g., as variance in case of a covariance operator.

    Examples
    --------
    >>> x_space = RGSpace(5)
    >>> D = DiagonalOperator(x_space, diagonal=[1., 3., 2., 4., 6.])
    >>> f = Field(x_space, val=2.)
    >>> res = D.times(f)
    >>> res.val
    <distributed_data_object>
    array([ 2.,  6.,  4.,  8.,  12.])

    See Also
    --------
    EndomorphicOperator

    """

    # ---Overwritten properties and methods---

    def __init__(self, domain=(), diagonal=None, bare=False, copy=True,
                 distribution_strategy=None, default_spaces=None):
        super(DiagonalOperator, self).__init__(default_spaces)

        self._domain = self._parse_domain(domain)

        # Prefer the distribution strategy of an already distributed input.
        if distribution_strategy is None:
            if isinstance(diagonal, (distributed_data_object, Field)):
                distribution_strategy = diagonal.distribution_strategy

        self._distribution_strategy = self._parse_distribution_strategy(
                               distribution_strategy=distribution_strategy,
                               val=diagonal)

        self.set_diagonal(diagonal=diagonal, bare=bare, copy=copy)

    def _times(self, x, spaces):
        return self._times_helper(x, spaces, operation=lambda z: z.__mul__)

    def _adjoint_times(self, x, spaces):
        return self._times_helper(
            x, spaces,
            operation=lambda z: self._adjoint_of(z).__mul__)

    def _inverse_times(self, x, spaces):
        return self._times_helper(x, spaces,
                                  operation=lambda z: z.__rtruediv__)

    def _adjoint_inverse_times(self, x, spaces):
        return self._times_helper(
            x, spaces,
            operation=lambda z: self._adjoint_of(z).__rtruediv__)

    def diagonal(self, bare=False, copy=True):
        """ Returns the diagonal of the Operator.

        Parameters
        ----------
        bare : boolean
            Whether the returned Field values should be bare or not.
        copy : boolean
            Whether the returned Field should be copied or not. Ignored if
            `bare` is True, since the inverse weighting already produces
            a new Field.

        Returns
        -------
        out : Field
            The diagonal of the Operator.

        """
        if bare:
            diagonal = self._diagonal.weight(power=-1)
        elif copy:
            diagonal = self._diagonal.copy()
        else:
            diagonal = self._diagonal
        return diagonal

    def inverse_diagonal(self, bare=False):
        """ Returns the inverse-diagonal of the operator.

        Parameters
        ----------
        bare : boolean
            Whether the returned Field values should be bare or not.

        Returns
        -------
        out : Field
            The inverse of the diagonal of the Operator.

        """
        return 1./self.diagonal(bare=bare, copy=False)

    # ---Mandatory properties and methods---

    @property
    def domain(self):
        return self._domain

    @property
    def self_adjoint(self):
        # Cached; reset by set_diagonal. A diagonal operator is self-adjoint
        # iff all diagonal entries are real.
        if self._self_adjoint is None:
            self._self_adjoint = (self._diagonal.val.imag == 0).all()
        return self._self_adjoint

    @property
    def unitary(self):
        # Cached; reset by set_diagonal. Unitary iff |d_i|^2 == 1 for all i.
        # NOTE(review): exact floating-point comparison; entries that are
        # unitary only up to rounding will report False.
        if self._unitary is None:
            self._unitary = (self._diagonal.val *
                             self._diagonal.val.conjugate() == 1).all()
        return self._unitary

    # ---Added properties and methods---

    @property
    def distribution_strategy(self):
        """
        distribution_strategy : string
            Defines the way how the diagonal operator is distributed
            among the nodes. Available distribution_strategies are:
            'fftw', 'equal' and 'not'.

        Notes :
            https://arxiv.org/abs/1606.05385

        """
        return self._distribution_strategy

    def _parse_distribution_strategy(self, distribution_strategy, val):
        if distribution_strategy is None:
            if isinstance(val, distributed_data_object):
                distribution_strategy = val.distribution_strategy
            elif isinstance(val, Field):
                distribution_strategy = val.distribution_strategy
            else:
                self.logger.info("Datamodel set to default!")
                distribution_strategy = gc['default_distribution_strategy']
        elif distribution_strategy not in DISTRIBUTION_STRATEGIES['all']:
            raise ValueError(
                    "Invalid distribution_strategy!")
        return distribution_strategy

    def set_diagonal(self, diagonal, bare=False, copy=True):
        """ Sets the diagonal of the Operator.

        Parameters
        ----------
        diagonal : {scalar, list, array, Field, d2o-object}
            The diagonal entries of the operator.
        bare : boolean
            Indicates whether the input for the diagonal is bare or not
            (default: False).
        copy : boolean
            Specifies if a copy of the input shall be made (default: True).

        """

        # use the casting functionality from Field to process `diagonal`
        f = Field(domain=self.domain,
                  val=diagonal,
                  distribution_strategy=self.distribution_strategy,
                  copy=copy)

        # Bare input still has to be weighted with the volume factors.
        # If `copy` is True, `f` owns its data and may be weighted in place.
        # Otherwise weight out of place so the caller's data stays untouched.
        # In both cases keep the returned field: an out-of-place weight()
        # leaves `f` itself unchanged.
        if bare:
            f = f.weight(inplace=copy)

        # Reset the cached self_adjoint property
        self._self_adjoint = None

        # Reset the cached unitarity property
        self._unitary = None

        # store the diagonal-field
        self._diagonal = f

    @staticmethod
    def _adjoint_of(z):
        """ Return the adjoint of the diagonal `z`.

        `_times_helper` passes either the diagonal Field (matching-domain
        path) or a plain numpy array of local diagonal values (sub-slice
        path). numpy arrays have no `adjoint()` method. For a diagonal
        matrix the adjoint is the element-wise complex conjugate, so that
        is used for arrays. Fields keep using their own `adjoint()`.
        """
        if isinstance(z, np.ndarray):
            return z.conjugate()
        return z.adjoint()

    def _times_helper(self, x, spaces, operation):
        # if the domain matches directly
        # -> multiply the fields directly
        if x.domain == self.domain:
            # here the actual multiplication takes place
            return operation(self.diagonal(copy=False))(x)

        # Otherwise the diagonal acts only on the axes belonging to `spaces`
        # (all axes of x if `spaces` is None).
        if spaces is None:
            active_axes = list(range(len(x.shape)))
        else:
            active_axes = []
            for space_index in spaces:
                active_axes += x.domain_axes[space_index]

        # if the distribution_strategy of self is sub-slice compatible to
        # the one of x, use the local data of self directly
        axes_local_distribution_strategy = \
            x.val.get_axes_local_distribution_strategy(active_axes)
        if axes_local_distribution_strategy == self.distribution_strategy:
            local_diagonal = self._diagonal.val.get_local_data(copy=False)
        else:
            # create an array that is sub-slice compatible
            self.logger.warning("The input field is not sub-slice compatible "
                                "to the distribution strategy of the "
                                "operator.")
            redistr_diagonal_val = self._diagonal.val.copy(
                distribution_strategy=axes_local_distribution_strategy)
            local_diagonal = redistr_diagonal_val.get_local_data(copy=False)

        local_x = x.val.get_local_data(copy=False)

        # Broadcast the local diagonal against the local data of x: keep the
        # extent of every active axis, size 1 elsewhere. Use the *local*
        # shape, since only the local chunks are combined here. It equals
        # the global shape when the data are not distributed.
        reshaper = [local_x.shape[i] if i in active_axes else 1
                    for i in range(len(x.shape))]
        reshaped_local_diagonal = np.reshape(local_diagonal, reshaper)

        # here the actual multiplication takes place
        local_result = operation(reshaped_local_diagonal)(local_x)

        result_field = x.copy_empty(dtype=local_result.dtype)
        result_field.val.set_local_data(local_result, copy=False)
        return result_field