linear_operator.py 4.99 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# NIFTy
# Copyright (C) 2017  Theo Steininger
#
# Author: Theo Steininger
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18

19
import abc
20
from nifty.nifty_meta import NiftyMeta
21

22
from keepers import Loggable
23
24
25
26
from nifty.field import Field
import nifty.nifty_utilities as utilities


27
class LinearOperator(Loggable, object):
    """Abstract base class for linear operators acting on NIFTy Fields.

    Concrete subclasses define the `domain`, `target` and `unitary`
    properties and override whichever of the private hooks (`_times`,
    `_adjoint_times`, `_inverse_times`, `_adjoint_inverse_times`,
    `_inverse_adjoint_times`) they support; the base versions raise
    NotImplementedError. The public methods first validate the input Field
    against the operator's domain (or target) and then dispatch to the
    matching hook. If the operator reports itself as unitary, a missing
    inverse/adjoint hook falls back on its equivalent counterpart, since for
    a unitary operator the inverse equals the adjoint.
    """

    __metaclass__ = NiftyMeta

    def __init__(self):
        pass

    def _parse_domain(self, domain):
        """Return `domain` as processed by `nifty_utilities.parse_domain`."""
        return utilities.parse_domain(domain)

    @abc.abstractproperty
    def domain(self):
        """Domain tuple that inputs of `times`, `adjoint_inverse_times` and
        `inverse_adjoint_times` are checked against."""
        raise NotImplementedError

    @abc.abstractproperty
    def target(self):
        """Domain tuple that inputs of `inverse_times` and `adjoint_times`
        are checked against."""
        raise NotImplementedError

    @abc.abstractproperty
    def unitary(self):
        """Whether the operator is unitary (inverse == adjoint).

        When True, unimplemented inverse/adjoint hooks fall back on their
        unitary-equivalent counterparts.
        """
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        """Shorthand for `times`."""
        return self.times(*args, **kwargs)

    def times(self, x, spaces=None, **kwargs):
        """Apply the operator to the Field `x`.

        Parameters
        ----------
        x : Field
            Input field, checked against `self.domain`.
        spaces : None or index/indices into `x.domain`
            Sanitized via `nifty_utilities.cast_axis_to_tuple`. If None,
            `x.domain` must equal `self.domain` exactly; otherwise the
            selected spaces of `x.domain` must equal `self.domain`
            element-wise.
        **kwargs
            Forwarded unchanged to the `_times` hook.

        Returns
        -------
        Whatever the subclass' `_times` hook returns.

        Raises
        ------
        ValueError
            If `x` is not a Field or its domain does not match.
        NotImplementedError
            If the subclass does not implement `_times`.
        """
        spaces = self._check_input_compatibility(x, spaces)
        return self._times(x, spaces, **kwargs)

    def inverse_times(self, x, spaces=None, **kwargs):
        """Apply the inverse operator to `x` (checked against `target`).

        Falls back on `_adjoint_times` for unitary operators that do not
        implement `_inverse_times`. See `times` for the parameters.
        """
        spaces = self._check_input_compatibility(x, spaces, inverse=True)
        return self._call_with_unitary_fallback(
            self._inverse_times, self._adjoint_times, x, spaces, kwargs)

    def adjoint_times(self, x, spaces=None, **kwargs):
        """Apply the adjoint operator to `x` (checked against `target`).

        Falls back on `_inverse_times` for unitary operators that do not
        implement `_adjoint_times`. See `times` for the parameters.
        """
        spaces = self._check_input_compatibility(x, spaces, inverse=True)
        return self._call_with_unitary_fallback(
            self._adjoint_times, self._inverse_times, x, spaces, kwargs)

    def adjoint_inverse_times(self, x, spaces=None, **kwargs):
        """Apply the inverse of the adjoint to `x` (checked against `domain`).

        For a unitary operator the inverse of the adjoint is the operator
        itself, so `_times` is used if `_adjoint_inverse_times` is not
        implemented. See `times` for the parameters.
        """
        spaces = self._check_input_compatibility(x, spaces)
        return self._call_with_unitary_fallback(
            self._adjoint_inverse_times, self._times, x, spaces, kwargs)

    def inverse_adjoint_times(self, x, spaces=None, **kwargs):
        """Apply the adjoint of the inverse to `x` (checked against `domain`).

        For a unitary operator this equals the operator itself, so `_times`
        is used if `_inverse_adjoint_times` is not implemented. See `times`
        for the parameters.
        """
        spaces = self._check_input_compatibility(x, spaces)
        return self._call_with_unitary_fallback(
            self._inverse_adjoint_times, self._times, x, spaces, kwargs)

    def _call_with_unitary_fallback(self, hook, fallback, x, spaces, kwargs):
        """Call `hook`; if it is not implemented and the operator is
        unitary, call the equivalent `fallback` instead.

        `kwargs` is taken as a plain dict (not **kwargs) so user keyword
        arguments can never collide with this helper's own parameter names.
        Any NotImplementedError is re-raised for non-unitary operators.
        """
        try:
            return hook(x, spaces, **kwargs)
        except NotImplementedError:
            if not self.unitary:
                raise
            return fallback(x, spaces, **kwargs)

    # The base hooks accept **kwargs so that, when a subclass does not
    # override them, they raise NotImplementedError (which triggers the
    # unitary fallback) rather than a TypeError about unexpected keywords.
    def _times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'times'.")

    def _adjoint_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'adjoint_times'.")

    def _inverse_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'inverse_times'.")

    def _adjoint_inverse_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'adjoint_inverse_times'.")

    def _inverse_adjoint_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'inverse_adjoint_times'.")

    def _check_input_compatibility(self, x, spaces, inverse=False):
        """Validate `x` against the operator and return sanitized `spaces`.

        Parameters
        ----------
        x : Field
            The input field.
        spaces : None or index/indices into `x.domain`
            Passed through `nifty_utilities.cast_axis_to_tuple`.
        inverse : bool
            If True, check against `self.target` instead of `self.domain`.

        Returns
        -------
        The sanitized `spaces` (None if no spaces were given).

        Raises
        ------
        ValueError
            If `x` is not a Field, if the number of selected spaces differs
            from the number of spaces of the operator's domain/target, or if
            the domains do not match.
        """
        if not isinstance(x, Field):
            raise ValueError(
                "supplied object is not a `nifty.Field`.")

        # sanitize the `spaces` input
        spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))

        # The inverse/adjoint directions consume fields living on `target`.
        self_domain = self.target if inverse else self.domain

        if spaces is None:
            # No spaces given: the domains of self and x must match
            # completely.
            if self_domain != x.domain:
                raise ValueError(
                    "The operator's and field's domains don't match.")
        else:
            # `spaces` selects which entries of x.domain the operator's
            # domain is applied to; they must correspond one-to-one.
            if len(spaces) != len(self_domain):
                raise ValueError(
                    "The number of spaces (%d) does not match the number "
                    "of spaces in the operator's domain (%d)."
                    % (len(spaces), len(self_domain)))
            for space_index, own_space in zip(spaces, self_domain):
                if x.domain[space_index] != own_space:
                    raise ValueError(
                        "The operator's and field's domains don't match.")

        return spaces

    def __repr__(self):
        return str(self.__class__)