linear_operator.py 4.92 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# NIFTy
# Copyright (C) 2017  Theo Steininger
#
# Author: Theo Steininger
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18

19
20
import abc

21
from keepers import Loggable
22
23
24
25
from nifty.field import Field
import nifty.nifty_utilities as utilities


theos's avatar
theos committed
26
class LinearOperator(Loggable, object):
    """Abstract base class for linear operators acting on `nifty.Field`s.

    Subclasses must provide the abstract properties `domain`, `target`,
    `implemented` and `unitary`, and override whichever of the private
    hooks (`_times`, `_adjoint_times`, `_inverse_times`,
    `_adjoint_inverse_times`, `_inverse_adjoint_times`) they support.

    The public methods (`times`, `inverse_times`, `adjoint_times`,
    `adjoint_inverse_times`, `inverse_adjoint_times`) validate the input
    field against the operator's domain (or target), optionally apply
    `Field.weight`, and dispatch to the corresponding hook.  Extra keyword
    arguments are forwarded to the hook.

    When `implemented` is False:
      * `times` / `adjoint_times` call `x.weight(spaces=spaces)` on the
        input before invoking the hook;
      * the inverse variants call `y.weight(power=-1, spaces=spaces)` on
        the hook's result.
    NOTE(review): presumably this compensates for volume factors the hook
    does not include itself -- confirm against concrete subclasses.

    When `unitary` is True, the adjoint variants are short-circuited:
    `adjoint_times` delegates to `inverse_times`, and
    `adjoint_inverse_times` / `inverse_adjoint_times` delegate to `times`.
    """

    __metaclass__ = abc.ABCMeta

    def __init__(self):
        pass

    def _parse_domain(self, domain):
        """Return `domain` normalized by `nifty_utilities.parse_domain`."""
        return utilities.parse_domain(domain)

    @abc.abstractproperty
    def domain(self):
        """Domain of the operator.

        Input fields of `times` and `adjoint_inverse_times` are checked
        against it (entry by entry when `spaces` is given).
        """
        raise NotImplementedError

    @abc.abstractproperty
    def target(self):
        """Target of the operator.

        Input fields of `inverse_times` and `adjoint_times` are checked
        against it (entry by entry when `spaces` is given).
        """
        raise NotImplementedError

    @abc.abstractproperty
    def implemented(self):
        """Whether the hooks handle weighting themselves.

        If False, the public methods apply `Field.weight` around the hook
        calls (see the class docstring).
        """
        raise NotImplementedError

    @abc.abstractproperty
    def unitary(self):
        """Whether the operator is unitary.

        If True, the adjoint methods delegate to `times` / `inverse_times`.
        """
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        """Alias for `times`."""
        return self.times(*args, **kwargs)

    def times(self, x, spaces=None, **kwargs):
        """Apply the operator to `x`.

        Parameters
        ----------
        x : Field
            Input field; checked against `self.domain`.
        spaces : None or anything accepted by `cast_axis_to_tuple`
            Indices into `x.domain` the operator acts on.  `None` means the
            operator's domain must equal `x.domain` as a whole.
        **kwargs
            Forwarded to `_times`.

        Returns
        -------
        The value returned by `_times`.

        Raises
        ------
        ValueError
            If `x` is not a Field or its domain does not match.
        """
        spaces = self._check_input_compatibility(x, spaces)

        if not self.implemented:
            x = x.weight(spaces=spaces)

        y = self._times(x, spaces, **kwargs)
        return y

    def inverse_times(self, x, spaces=None, **kwargs):
        """Apply the inverse operator to `x` (checked against `target`).

        Parameters and exceptions are as for `times`; `**kwargs` is
        forwarded to `_inverse_times`.  If `implemented` is False the
        result is weighted with `power=-1`.
        """
        spaces = self._check_input_compatibility(x, spaces, inverse=True)

        y = self._inverse_times(x, spaces, **kwargs)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def adjoint_times(self, x, spaces=None, **kwargs):
        """Apply the adjoint operator to `x` (checked against `target`).

        For unitary operators this delegates to `inverse_times`.
        Otherwise parameters and exceptions are as for `times`, and
        `**kwargs` is forwarded to `_adjoint_times`.
        """
        if self.unitary:
            # Forward kwargs so the unitary shortcut behaves like the
            # regular path.
            return self.inverse_times(x, spaces, **kwargs)

        spaces = self._check_input_compatibility(x, spaces, inverse=True)

        if not self.implemented:
            x = x.weight(spaces=spaces)
        y = self._adjoint_times(x, spaces, **kwargs)
        return y

    def adjoint_inverse_times(self, x, spaces=None, **kwargs):
        """Apply the adjoint of the inverse to `x` (checked against
        `domain`).

        For unitary operators this delegates to `times`.  Otherwise
        `**kwargs` is forwarded to `_adjoint_inverse_times`, and the result
        is weighted with `power=-1` if `implemented` is False.
        """
        if self.unitary:
            return self.times(x, spaces, **kwargs)

        spaces = self._check_input_compatibility(x, spaces)

        y = self._adjoint_inverse_times(x, spaces, **kwargs)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def inverse_adjoint_times(self, x, spaces=None, **kwargs):
        """Apply the inverse of the adjoint to `x` (checked against
        `domain`).

        For unitary operators this delegates to `times`.  Otherwise
        `**kwargs` is forwarded to `_inverse_adjoint_times`, and the result
        is weighted with `power=-1` if `implemented` is False.
        """
        if self.unitary:
            return self.times(x, spaces, **kwargs)

        spaces = self._check_input_compatibility(x, spaces)

        y = self._inverse_adjoint_times(x, spaces, **kwargs)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    # Default hooks.  They accept **kwargs so that an un-overridden hook
    # always reports NotImplementedError rather than a TypeError about an
    # unexpected keyword argument.

    def _times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'times'.")

    def _adjoint_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'adjoint_times'.")

    def _inverse_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'inverse_times'.")

    def _adjoint_inverse_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'adjoint_inverse_times'.")

    def _inverse_adjoint_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'inverse_adjoint_times'.")

    def _check_input_compatibility(self, x, spaces, inverse=False):
        """Validate `x` against the operator and return sanitized `spaces`.

        Parameters
        ----------
        x : Field
            The field the operator is about to be applied to.
        spaces : None or anything accepted by `cast_axis_to_tuple`
            Indices into `x.domain` selecting the sub-spaces the operator
            acts on.
        inverse : bool
            If True, `x` is compared with `self.target` instead of
            `self.domain` (used by `inverse_times` and `adjoint_times`).

        Returns
        -------
        The result of `utilities.cast_axis_to_tuple(spaces, len(x.domain))`.

        Raises
        ------
        ValueError
            If `x` is not a Field, if `spaces` selects a different number of
            spaces than the operator's domain has, or if any compared
            domain entry differs.
        """
        if not isinstance(x, Field):
            raise ValueError(
                "supplied object is not a `nifty.Field`.")

        # sanitize the `spaces` input
        spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))

        # Two valid cases:
        # 1. `spaces` is None: the operator's domain must equal the full
        #    domain-tuple of x.
        # 2. `spaces` is given: it selects one entry of x.domain for each
        #    entry of the operator's domain, and they must match pairwise.
        if inverse:
            self_domain = self.target
        else:
            self_domain = self.domain

        if spaces is None:
            if self_domain != x.domain:
                raise ValueError(
                    "The operator's and field's domains don't match.")
        else:
            # Without this check, a too-short `spaces` would pass after only
            # a partial comparison and a too-long one would raise IndexError.
            if len(spaces) != len(self_domain):
                raise ValueError(
                    "The number of spaces (%d) does not match the number "
                    "of spaces in the operator's domain (%d)."
                    % (len(spaces), len(self_domain)))
            for space_index, operator_space in zip(spaces, self_domain):
                if x.domain[space_index] != operator_space:
                    raise ValueError(
                        "The operator's and field's domains don't match.")

        return spaces

    def __repr__(self):
        return str(self.__class__)