linear_operator.py 8.85 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# NIFTy
# Copyright (C) 2017  Theo Steininger
#
# Author: Theo Steininger
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18

19
import abc
20
from nifty.nifty_meta import NiftyMeta
21

22
from keepers import Loggable
23
24
25
26
from nifty.field import Field
import nifty.nifty_utilities as utilities


theos's avatar
theos committed
27
class LinearOperator(Loggable, object):
    """NIFTY base class for linear operators.

    The base NIFTY operator class is an abstract class from which
    other specific operator subclasses, including those preimplemented
    in NIFTY (e.g. the EndomorphicOperator, ProjectionOperator,
    DiagonalOperator, SmoothingOperator, ResponseOperator,
    PropagatorOperator, ComposedOperator) must be derived.

    Subclasses provide the `domain`, `target` and `unitary` properties and
    implement the private hooks `_times`, `_inverse_times`,
    `_adjoint_times`, `_adjoint_inverse_times` and `_inverse_adjoint_times`.
    The public methods validate their input Field and then dispatch to
    those hooks.

    Attributes
    ----------
    domain : NIFTy.space
        The NIFTy.space(s) on which the operator is defined (its input).
    target : NIFTy.space
        The NIFTy.space(s) in which the outcome of the operator lives.
    unitary : boolean
        Indicates whether the operator is unitary or not.

    Raises
    ------
    NotImplementedError
        Raised if
            * domain is not defined
            * target is not defined
            * unitary is not set to (True/False)
            * a private `_..._times` hook is called without being
              overridden by the subclass

    Notes
    -----
    All Operators within NIFTy are linear and must therefore be subclasses
    of LinearOperator. A LinearOperator must have the attributes domain,
    target and unitary to be properly defined.

    See Also
    --------
    EndomorphicOperator, ProjectionOperator,
    DiagonalOperator, SmoothingOperator, ResponseOperator,
    PropagatorOperator, ComposedOperator

    """

    # Exactly one metaclass assignment: in a class body the last assignment
    # wins, so a second `__metaclass__ = abc.ABCMeta` would silently discard
    # NiftyMeta. (Python 2 mechanism; Python 3 ignores this attribute.)
    # NOTE(review): the abstractproperty declarations below are only
    # enforced if NiftyMeta derives from abc.ABCMeta -- confirm in
    # nifty/nifty_meta.py.
    __metaclass__ = NiftyMeta

    def __init__(self):
        pass

    def _parse_domain(self, domain):
        """Normalize `domain` using `nifty_utilities.parse_domain`."""
        return utilities.parse_domain(domain)

    @abc.abstractproperty
    def domain(self):
        """The space(s) the operator's input Field lives on."""
        raise NotImplementedError

    @abc.abstractproperty
    def target(self):
        """The space(s) the operator's output Field lives on."""
        raise NotImplementedError

    @abc.abstractproperty
    def unitary(self):
        """True if the operator is unitary, False otherwise."""
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        """Shorthand for `times`; all arguments are forwarded unchanged."""
        return self.times(*args, **kwargs)

    def times(self, x, spaces=None, **kwargs):
        """ Applies the Operator to a given Field.

        The Field has to live over the operator's domain (entirely, or on
        the sub-spaces selected by `spaces`).

        Parameters
        ----------
        x : NIFTY.Field
            The Field the Operator is applied to.
        spaces : integer or tuple of integers (default: None)
            Index (or indices) of the Field's domain spaces the Operator acts
            on; normalized via `utilities.cast_axis_to_tuple`. If None, the
            Field's full domain must equal the operator's domain.
        **kwargs
            Additional keyword arguments are forwarded to the subclass
            implementation `_times`.

        Returns
        -------
        out : NIFTy.Field
            The processed Field living on the target space.

        Raises
        ------
        ValueError
            If `x` is not a Field or its domain does not match the
            operator's domain.

        """
        spaces = self._check_input_compatibility(x, spaces)
        return self._times(x, spaces, **kwargs)

    def inverse_times(self, x, spaces=None, **kwargs):
        """ Applies the inverse-Operator to a given Field.

        The Field has to live over the operator's target (entirely, or on
        the sub-spaces selected by `spaces`).

        Parameters
        ----------
        x : NIFTY.Field
            The Field the inverse Operator is applied to.
        spaces : integer or tuple of integers (default: None)
            Index (or indices) of the Field's domain spaces the Operator acts
            on; normalized via `utilities.cast_axis_to_tuple`. If None, the
            Field's full domain must equal the operator's target.
        **kwargs
            Additional keyword arguments are forwarded to the subclass
            implementation `_inverse_times`.

        Returns
        -------
        out : NIFTy.Field
            The processed Field, mapped from the target back to the domain.

        Raises
        ------
        ValueError
            If `x` is not a Field or its domain does not match the
            operator's target.

        """
        spaces = self._check_input_compatibility(x, spaces, inverse=True)
        return self._inverse_times(x, spaces, **kwargs)

    def adjoint_times(self, x, spaces=None, **kwargs):
        """ Applies the adjoint-Operator to a given Field.

        The Field has to live over the operator's target (entirely, or on
        the sub-spaces selected by `spaces`).

        Parameters
        ----------
        x : NIFTY.Field
            The Field the adjoint Operator is applied to.
        spaces : integer or tuple of integers (default: None)
            Index (or indices) of the Field's domain spaces the Operator acts
            on; normalized via `utilities.cast_axis_to_tuple`. If None, the
            Field's full domain must equal the operator's target.
        **kwargs
            Additional keyword arguments are forwarded to the subclass
            implementation (`_adjoint_times`, or `_inverse_times` for a
            unitary operator).

        Returns
        -------
        out : NIFTy.Field
            The processed Field, mapped from the target back to the domain.

        Raises
        ------
        ValueError
            If `x` is not a Field or its domain does not match the
            operator's target.

        """
        if self.unitary:
            # For a unitary operator the adjoint equals the inverse.
            return self.inverse_times(x, spaces, **kwargs)

        spaces = self._check_input_compatibility(x, spaces, inverse=True)
        return self._adjoint_times(x, spaces, **kwargs)

    def adjoint_inverse_times(self, x, spaces=None, **kwargs):
        """ Applies the adjoint-inverse Operator to a given Field.

        The Field has to live over the operator's domain (entirely, or on
        the sub-spaces selected by `spaces`).

        Parameters
        ----------
        x : NIFTY.Field
            The Field the adjoint-inverse Operator is applied to.
        spaces : integer or tuple of integers (default: None)
            Index (or indices) of the Field's domain spaces the Operator acts
            on; normalized via `utilities.cast_axis_to_tuple`. If None, the
            Field's full domain must equal the operator's domain.
        **kwargs
            Additional keyword arguments are forwarded to the subclass
            implementation (`_adjoint_inverse_times`, or `_times` for a
            unitary operator).

        Returns
        -------
        out : NIFTy.Field
            The processed Field living on the target space.

        Raises
        ------
        ValueError
            If `x` is not a Field or its domain does not match the
            operator's domain.

        """
        if self.unitary:
            # For a unitary operator (U^-1)^dagger == U.
            return self.times(x, spaces, **kwargs)

        spaces = self._check_input_compatibility(x, spaces)
        return self._adjoint_inverse_times(x, spaces, **kwargs)

    def inverse_adjoint_times(self, x, spaces=None, **kwargs):
        """ Applies the inverse-adjoint Operator to a given Field.

        The Field has to live over the operator's domain (entirely, or on
        the sub-spaces selected by `spaces`).

        Parameters
        ----------
        x : NIFTY.Field
            The Field the inverse-adjoint Operator is applied to.
        spaces : integer or tuple of integers (default: None)
            Index (or indices) of the Field's domain spaces the Operator acts
            on; normalized via `utilities.cast_axis_to_tuple`. If None, the
            Field's full domain must equal the operator's domain.
        **kwargs
            Additional keyword arguments are forwarded to the subclass
            implementation (`_inverse_adjoint_times`, or `_times` for a
            unitary operator).

        Returns
        -------
        out : NIFTy.Field
            The processed Field living on the target space.

        Raises
        ------
        ValueError
            If `x` is not a Field or its domain does not match the
            operator's domain.

        """
        if self.unitary:
            # For a unitary operator (U^dagger)^-1 == U.
            return self.times(x, spaces, **kwargs)

        spaces = self._check_input_compatibility(x, spaces)
        return self._inverse_adjoint_times(x, spaces, **kwargs)

    # Private hooks to be overridden by subclasses. They accept **kwargs so
    # that the public methods' keyword forwarding reaches the intended
    # NotImplementedError instead of a TypeError when left unimplemented.

    def _times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'times'.")

    def _adjoint_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'adjoint_times'.")

    def _inverse_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'inverse_times'.")

    def _adjoint_inverse_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'adjoint_inverse_times'.")

    def _inverse_adjoint_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'inverse_adjoint_times'.")

    def _check_input_compatibility(self, x, spaces, inverse=False):
        """Validate an input Field and return the normalized `spaces`.

        Parameters
        ----------
        x : NIFTY.Field
            The Field the operator is about to be applied to.
        spaces : integer, tuple of integers or None
            Requested spaces of `x.domain`; normalized with
            `utilities.cast_axis_to_tuple`.
        inverse : boolean (default: False)
            If False, `x` is checked against `self.domain`; if True, against
            `self.target` (used by the inverse and adjoint directions).

        Returns
        -------
        spaces : tuple of integers or None
            The normalized `spaces` value.

        Raises
        ------
        ValueError
            If `x` is not a Field, if the number of selected spaces differs
            from the number of spaces of the operator's (co)domain, or if
            the domains do not match.

        """
        if not isinstance(x, Field):
            raise ValueError(
                "supplied object is not a `nifty.Field`.")

        # sanitize the `spaces` input
        spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))

        # Forward application consumes Fields on `domain`; the inverse and
        # adjoint directions consume Fields on `target`.
        self_domain = self.target if inverse else self.domain

        # If the operator's domain is set, there are two valid cases:
        # 1. The user specifies with `spaces` that the operator's domain
        #    should be applied to certain spaces in the domain-tuple of x.
        # 2. The domains of self and x match completely.
        if spaces is None:
            if self_domain != x.domain:
                raise ValueError(
                    "The operator's and field's domains don't "
                    "match.")
        else:
            # Without this check, too many indices fail with a bare
            # IndexError and too few are silently accepted as a partial
            # match.
            if len(spaces) != len(self_domain):
                raise ValueError(
                    "The operator's domain consists of %d space(s), but "
                    "`spaces` selects %d." % (len(self_domain), len(spaces)))
            for space_index, own_space in zip(spaces, self_domain):
                if x.domain[space_index] != own_space:
                    raise ValueError(
                        "The operator's and field's domains don't "
                        "match.")

        return spaces

    def __repr__(self):
        return str(self.__class__)