# -*- coding: utf-8 -*-

import abc

from keepers import Loggable
from nifty.field import Field
from nifty.spaces import Space
from nifty.field_types import FieldType
import nifty.nifty_utilities as utilities


class LinearOperator(Loggable, object):
    """Abstract base class for linear operators acting on NIFTy fields.

    Concrete operators provide the abstract properties `domain`, `target`,
    `field_type`, `field_type_target`, `implemented` and `unitary`. They
    override whichever of the private hooks `_times`, `_adjoint_times`,
    `_inverse_times`, `_adjoint_inverse_times` and `_inverse_adjoint_times`
    they support.

    Each public method of the same name (without the leading underscore)
    does three things:

    1. Validates the input field against the operator.
    2. Calls `Field.weight` on the input or output when the operator is not
       flagged as `implemented`.
    3. Dispatches to the corresponding hook.
    """

    # NOTE(review): the `__metaclass__` attribute is only honoured by
    # Python 2. Python 3 ignores it, so there the abstract properties below
    # are not enforced when a subclass is instantiated.
    __metaclass__ = abc.ABCMeta

    def __init__(self):
        pass

    # ------------------------------------------------------------------
    # Input normalization helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _cast_to_typed_tuple(obj, required_type, error_message):
        """Normalize `obj` to a tuple whose items are all `required_type`.

        Conversion rules:

        - `None` becomes `()`.
        - A single `required_type` instance is wrapped in a 1-tuple.
        - Any other non-tuple iterable is converted with `tuple()`.

        Raises
        ------
        TypeError
            With `error_message`, if any resulting item is not an instance
            of `required_type`. `tuple()` itself raises `TypeError` for
            non-iterable input.
        """
        if obj is None:
            obj = ()
        elif isinstance(obj, required_type):
            obj = (obj,)
        elif not isinstance(obj, tuple):
            obj = tuple(obj)

        if not all(isinstance(item, required_type) for item in obj):
            raise TypeError(error_message)
        return obj

    def _parse_domain(self, domain):
        """Return `domain` as a tuple of `Space` instances.

        Accepts `None` (giving `()`), a single `Space`, or any iterable of
        `Space` objects.

        Raises
        ------
        TypeError
            If any element is not a `Space`.
        """
        return self._cast_to_typed_tuple(
            domain, Space,
            "Given object contains something that is not a "
            "nifty.space.")

    def _parse_field_type(self, field_type):
        """Return `field_type` as a tuple of `FieldType` instances.

        Accepts `None` (giving `()`), a single `FieldType`, or any iterable
        of `FieldType` objects.

        Raises
        ------
        TypeError
            If any element is not a `FieldType`.
        """
        return self._cast_to_typed_tuple(
            field_type, FieldType,
            "Given object is not a nifty.FieldType.")

    # ------------------------------------------------------------------
    # Abstract properties
    # ------------------------------------------------------------------

    @abc.abstractproperty
    def domain(self):
        """Spaces expected on the input of `times` / `adjoint_inverse_times`.

        An empty tuple makes the operator accept a field with any domain
        when no `spaces` are given.
        """
        raise NotImplementedError

    @abc.abstractproperty
    def target(self):
        """Spaces expected on the input of `inverse_times` / `adjoint_times`."""
        raise NotImplementedError

    @abc.abstractproperty
    def field_type(self):
        """Field types expected alongside `domain`."""
        raise NotImplementedError

    @abc.abstractproperty
    def field_type_target(self):
        """Field types expected alongside `target`."""
        raise NotImplementedError

    @abc.abstractproperty
    def implemented(self):
        """Whether the hooks handle weighting themselves.

        If False, the public methods call `Field.weight` on the input (or,
        for inverse-type methods, with `power=-1` on the output) around
        the hook call.
        """
        raise NotImplementedError

    @abc.abstractproperty
    def unitary(self):
        """Whether the operator is unitary.

        If True, the adjoint methods are routed to `inverse_times` and
        `times` instead of the adjoint hooks.
        """
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Public application methods
    # ------------------------------------------------------------------

    def __call__(self, *args, **kwargs):
        """Alias for `times`."""
        return self.times(*args, **kwargs)

    def times(self, x, spaces=None, types=None, **kwargs):
        """Apply the operator to the field `x`.

        Parameters
        ----------
        x : Field
            Input field. It must match `domain` / `field_type`.
        spaces, types : None, int or sequence of int
            Indices into `x.domain` / `x.field_type` the operator acts on.
            `None` means the whole field must match.
        **kwargs
            Forwarded to `_times`.

        Raises
        ------
        ValueError
            If `x` is not a `Field` or does not match the operator.
        """
        spaces, types = self._check_input_compatibility(x, spaces, types)

        if not self.implemented:
            x = x.weight(spaces=spaces)

        return self._times(x, spaces, types, **kwargs)

    def inverse_times(self, x, spaces=None, types=None, **kwargs):
        """Apply the inverse operator to `x`.

        `x` must match `target` / `field_type_target`. When the operator
        is not `implemented`, the result is weighted with `power=-1`.
        """
        spaces, types = self._check_input_compatibility(x, spaces, types,
                                                        inverse=True)

        y = self._inverse_times(x, spaces, types, **kwargs)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def adjoint_times(self, x, spaces=None, types=None, **kwargs):
        """Apply the adjoint operator to `x`.

        `x` must match `target` / `field_type_target`. For a unitary
        operator this delegates to `inverse_times`.
        """
        if self.unitary:
            # Unitary: adjoint == inverse. Keep forwarding the caller's
            # keyword arguments.
            return self.inverse_times(x, spaces, types, **kwargs)

        spaces, types = self._check_input_compatibility(x, spaces, types,
                                                        inverse=True)

        if not self.implemented:
            x = x.weight(spaces=spaces)
        return self._adjoint_times(x, spaces, types, **kwargs)

    def adjoint_inverse_times(self, x, spaces=None, types=None, **kwargs):
        """Apply the inverse of the adjoint operator to `x`.

        `x` must match `domain` / `field_type`. For a unitary operator
        this equals the operator itself, so it delegates to `times`.
        """
        if self.unitary:
            return self.times(x, spaces, types, **kwargs)

        spaces, types = self._check_input_compatibility(x, spaces, types)

        y = self._adjoint_inverse_times(x, spaces, types, **kwargs)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def inverse_adjoint_times(self, x, spaces=None, types=None, **kwargs):
        """Apply the adjoint of the inverse operator to `x`.

        `x` must match `domain` / `field_type`. For a unitary operator
        this delegates to `times`.
        """
        if self.unitary:
            return self.times(x, spaces, types, **kwargs)

        spaces, types = self._check_input_compatibility(x, spaces, types)

        # Forward kwargs here as well, like every other application method.
        y = self._inverse_adjoint_times(x, spaces, types, **kwargs)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    # ------------------------------------------------------------------
    # Hooks to be overridden by concrete operators
    # ------------------------------------------------------------------
    # The stubs accept **kwargs because the public methods always forward
    # them. Without it, an un-overridden hook called with kwargs would raise
    # TypeError instead of the intended NotImplementedError.

    def _times(self, x, spaces, types, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'times'.")

    def _adjoint_times(self, x, spaces, types, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'adjoint_times'.")

    def _inverse_times(self, x, spaces, types, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'inverse_times'.")

    def _adjoint_inverse_times(self, x, spaces, types, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'adjoint_inverse_times'.")

    def _inverse_adjoint_times(self, x, spaces, types, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'inverse_adjoint_times'.")

    # ------------------------------------------------------------------
    # Input validation
    # ------------------------------------------------------------------

    @staticmethod
    def _assert_axes_match(field_items, axes, operator_items, error_message):
        """Raise `ValueError(error_message)` unless the field matches.

        If `axes` is None, `field_items` must equal `operator_items`. An
        empty `operator_items` accepts any field.

        Otherwise `axes` must select exactly `len(operator_items)` entries
        of `field_items`, and each selected entry must equal the
        corresponding operator entry.
        """
        if axes is None:
            if operator_items != () and operator_items != field_items:
                raise ValueError(error_message)
            return

        # Without this check a too-short `axes` would silently pass and a
        # too-long one would raise IndexError instead of ValueError.
        if len(axes) != len(operator_items):
            raise ValueError(error_message)

        for axis, operator_item in zip(axes, operator_items):
            if field_items[axis] != operator_item:
                raise ValueError(error_message)

    def _check_input_compatibility(self, x, spaces, types, inverse=False):
        """Validate `x` against the operator and normalize `spaces`/`types`.

        Parameters
        ----------
        x : Field
            The field the operator is about to be applied to.
        spaces, types : None, int or sequence of int
            Indices into `x.domain` / `x.field_type`. They are normalized
            with `utilities.cast_axis_to_tuple`. The checks below treat a
            None result as "match the whole field".
        inverse : bool
            If True, check against `target` / `field_type_target`.
            Otherwise check against `domain` / `field_type`.

        Returns
        -------
        tuple
            The normalized `(spaces, types)`.

        Raises
        ------
        ValueError
            If `x` is not a `Field`, or its domain or field types do not
            match the operator.
        """
        if not isinstance(x, Field):
            raise ValueError(
                "supplied object is not a `nifty.Field`.")

        # Sanitize the `spaces` and `types` input.
        spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
        types = utilities.cast_axis_to_tuple(types, len(x.field_type))

        if inverse:
            self_domain = self.target
            self_field_type = self.field_type_target
        else:
            self_domain = self.domain
            self_field_type = self.field_type

        # If the operator's domain is non-empty, there are two valid cases:
        #   1. `spaces` selects entries of x.domain that match the
        #      operator's domain one-to-one.
        #   2. `spaces` is None and the domains of self and x match
        #      completely.
        # The same logic applies to field types.
        self._assert_axes_match(
            x.domain, spaces, self_domain,
            "The operator's and field's domains don't match.")
        self._assert_axes_match(
            x.field_type, types, self_field_type,
            "The operator's and field's field_types don't match.")

        return (spaces, types)

    def __repr__(self):
        return str(self.__class__)