linear_operator.py 9.09 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# NIFTy
# Copyright (C) 2017  Theo Steininger
#
# Author: Theo Steininger
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18

19
import abc
20
from nifty.nifty_meta import NiftyMeta
21

22
from keepers import Loggable
23
24
25
26
from nifty.field import Field
import nifty.nifty_utilities as utilities


theos's avatar
theos committed
27
class LinearOperator(Loggable, object):
    """NIFTY base class for linear operators.

    The base NIFTY operator class is an abstract class from which
    other specific operator subclasses, including those preimplemented
    in NIFTY (e.g. the EndomorphicOperator, ProjectionOperator,
    DiagonalOperator, SmoothingOperator, ResponseOperator,
    PropagatorOperator, ComposedOperator) are derived.

    Parameters
    ----------
    default_spaces : tuple of ints *optional*
        Defines on which space(s) of a given field the Operator acts by
        default (default: None)

    Attributes
    ----------
    domain : tuple of DomainObjects, i.e. Spaces and FieldTypes
        The domain on which the Operator's input Field lives.
    target : tuple of DomainObjects, i.e. Spaces and FieldTypes
        The domain in which the Operator's result lives.
    unitary : boolean
        Indicates whether the Operator is unitary or not.
    default_spaces : tuple of ints or None
        The space(s) used by the public application methods whenever the
        caller does not pass `spaces` explicitly.

    Raises
    ------
    NotImplementedError
        Raised if
            * domain is not defined
            * target is not defined
            * unitary is not set to (True/False)

    Notes
    -----
    All Operators within NIFTy are linear and must therefore be subclasses
    of the LinearOperator. A LinearOperator must have the attributes domain,
    target and unitary to be properly defined.

    Subclasses provide the actual numerics by overriding the private hooks
    `_times`, `_adjoint_times`, `_inverse_times` and
    `_adjoint_inverse_times`. The public methods `times`, `adjoint_times`,
    `inverse_times` and `adjoint_inverse_times` validate the input and, for
    unitary operators, fall back to an equivalent hook when the requested
    one is not implemented.

    See Also
    --------
    EndomorphicOperator, ProjectionOperator,
    DiagonalOperator, SmoothingOperator, ResponseOperator,
    PropagatorOperator, ComposedOperator

    """

    __metaclass__ = NiftyMeta

    def __init__(self, default_spaces=None):
        # Goes through the `default_spaces` property setter, which
        # normalizes the value via `utilities.cast_axis_to_tuple`.
        self.default_spaces = default_spaces

    @staticmethod
    def _parse_domain(domain):
        """Normalize `domain` by delegating to `utilities.parse_domain`."""
        return utilities.parse_domain(domain)

    @abc.abstractproperty
    def domain(self):
        """
        domain : tuple of DomainObjects, i.e. Spaces and FieldTypes
            The domain on which the Operator's input Field lives.
            Every Operator which inherits from the abstract LinearOperator
            base class must have this attribute.

        """

        raise NotImplementedError

    @abc.abstractproperty
    def target(self):
        """
        target : tuple of DomainObjects, i.e. Spaces and FieldTypes
            The domain on which the Operator's output Field lives.
            Every Operator which inherits from the abstract LinearOperator
            base class must have this attribute.

        """

        raise NotImplementedError

    @abc.abstractproperty
    def unitary(self):
        """
        unitary : boolean
            States whether the Operator is unitary or not.
            Every Operator which inherits from the abstract LinearOperator
            base class must have this attribute.

        """

        raise NotImplementedError

    @property
    def default_spaces(self):
        """
        default_spaces : tuple of ints or None
            The space(s) the Operator acts on when no `spaces` argument is
            passed to the public application methods.

        """
        return self._default_spaces

    @default_spaces.setter
    def default_spaces(self, spaces):
        # Store the value in the normalized form produced by
        # `utilities.cast_axis_to_tuple`.
        self._default_spaces = utilities.cast_axis_to_tuple(spaces)

    def __call__(self, *args, **kwargs):
        """Apply the Operator; shorthand for `self.times(*args, **kwargs)`."""
        return self.times(*args, **kwargs)

    def times(self, x, spaces=None):
        """ Applies the Operator to a given Field.

        Operator and Field have to live over the same domain.

        Parameters
        ----------
        x : Field
            The input Field.
        spaces : tuple of ints
            Defines on which space(s) of the given Field the Operator acts.
            If None, `default_spaces` is used.

        Returns
        -------
        out : Field
            The processed Field living on the target-domain.

        Raises
        ------
        ValueError
            If `x` is not a Field or its domain does not match the
            Operator's domain.
        NotImplementedError
            If the subclass does not implement `_times`.

        """

        spaces = self._check_input_compatibility(x, spaces)
        return self._times(x, spaces)

    def inverse_times(self, x, spaces=None):
        """ Applies the inverse-Operator to a given Field.

        The Field has to live over the Operator's target.

        Parameters
        ----------
        x : Field
            The input Field.
        spaces : tuple of ints
            Defines on which space(s) of the given Field the Operator acts.
            If None, `default_spaces` is used.

        Returns
        -------
        out : Field
            The processed Field (expected to live on the Operator's domain).

        Raises
        ------
        ValueError
            If `x` is not a Field or its domain does not match the
            Operator's target.
        NotImplementedError
            If `_inverse_times` is not implemented and the Operator is not
            unitary.

        """

        spaces = self._check_input_compatibility(x, spaces, inverse=True)

        try:
            return self._inverse_times(x, spaces)
        except NotImplementedError:
            # For a unitary operator the inverse coincides with the adjoint,
            # so the adjoint hook is a valid substitute.
            if not self.unitary:
                raise
            return self._adjoint_times(x, spaces)

    def adjoint_times(self, x, spaces=None):
        """ Applies the adjoint-Operator to a given Field.

        The Field has to live over the Operator's target.

        Parameters
        ----------
        x : Field
            The input Field.
        spaces : tuple of ints
            Defines on which space(s) of the given Field the Operator acts.
            If None, `default_spaces` is used.

        Returns
        -------
        out : Field
            The processed Field (expected to live on the Operator's domain).

        Raises
        ------
        ValueError
            If `x` is not a Field or its domain does not match the
            Operator's target.
        NotImplementedError
            If the required hook is not implemented.

        Notes
        -----
        For unitary operators this call is forwarded to `inverse_times`,
        which itself falls back to `_adjoint_times` when `_inverse_times` is
        not implemented.

        """

        if self.unitary:
            return self.inverse_times(x, spaces)

        # Only non-unitary operators reach this point, so there is no
        # inverse-based fallback: a missing `_adjoint_times` simply
        # propagates its NotImplementedError.
        spaces = self._check_input_compatibility(x, spaces, inverse=True)
        return self._adjoint_times(x, spaces)

    def adjoint_inverse_times(self, x, spaces=None):
        """ Applies the adjoint-inverse Operator to a given Field.

        Operator and Field have to live over the same domain.

        Parameters
        ----------
        x : Field
            The input Field.
        spaces : tuple of ints
            Defines on which space(s) of the given Field the Operator acts.
            If None, `default_spaces` is used.

        Returns
        -------
        out : Field
            The processed Field living on the target-domain.

        Raises
        ------
        ValueError
            If `x` is not a Field or its domain does not match the
            Operator's domain.
        NotImplementedError
            If `_adjoint_inverse_times` is not implemented and the Operator
            is not unitary.

        Notes
        -----
        If the operator has an `inverse` then the inverse adjoint is identical
        to the adjoint inverse. We provide both names for convenience.

        See Also
        --------
        inverse_adjoint_times

        """

        spaces = self._check_input_compatibility(x, spaces)

        try:
            return self._adjoint_inverse_times(x, spaces)
        except NotImplementedError:
            # For a unitary operator the adjoint-inverse is the operator
            # itself, so the forward hook is a valid substitute.
            if not self.unitary:
                raise
            return self._times(x, spaces)

    def inverse_adjoint_times(self, x, spaces=None):
        """Alias of `adjoint_inverse_times`; see there for details."""
        return self.adjoint_inverse_times(x, spaces)

    # ---- hooks to be overridden by subclasses ----

    def _times(self, x, spaces):
        raise NotImplementedError(
            "no generic instance method 'times'.")

    def _adjoint_times(self, x, spaces):
        raise NotImplementedError(
            "no generic instance method 'adjoint_times'.")

    def _inverse_times(self, x, spaces):
        raise NotImplementedError(
            "no generic instance method 'inverse_times'.")

    def _adjoint_inverse_times(self, x, spaces):
        raise NotImplementedError(
            "no generic instance method 'adjoint_inverse_times'.")

    def _check_input_compatibility(self, x, spaces, inverse=False):
        """Validate `x` against the Operator and normalize `spaces`.

        Parameters
        ----------
        x : Field
            The Field the Operator is about to be applied to.
        spaces : tuple of ints or None
            The space(s) of `x` the Operator should act on. If None,
            `default_spaces` is used.
        inverse : boolean
            If False, `x` is compared against `self.domain`; if True,
            against `self.target` (default: False).

        Returns
        -------
        spaces : tuple of ints or None
            `spaces` after defaulting and after normalization by
            `utilities.cast_axis_to_tuple`.

        Raises
        ------
        ValueError
            If `x` is not a Field, or if the relevant domains do not match.

        """
        if not isinstance(x, Field):
            raise ValueError(
                "supplied object is not a `Field`.")

        if spaces is None:
            spaces = self.default_spaces

        # sanitize the `spaces` input
        spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))

        # There are two valid cases:
        # 1. Case:
        #   The user specifies with `spaces` that the operator's domain
        #   should be applied to certain spaces in the domain-tuple of x.
        # 2. Case:
        #   The domains of self and x match completely.

        # The inverse/adjoint directions consume Fields living on the target.
        self_domain = self.target if inverse else self.domain

        if spaces is None:
            if self_domain != x.domain:
                raise ValueError(
                    "The operator's and and field's domains don't "
                    "match.")
        else:
            # NOTE(review): only the entries addressed by `spaces` are
            # compared. A `spaces` tuple shorter than the operator's domain
            # is not rejected here, and a longer one surfaces as an
            # IndexError -- confirm whether an explicit length check is
            # wanted.
            for i, space_index in enumerate(spaces):
                if x.domain[space_index] != self_domain[i]:
                    raise ValueError(
                        "The operator's and and field's domains don't "
                        "match.")

        return spaces

    def __repr__(self):
        return str(self.__class__)