# -*- coding: utf-8 -*-

import abc

from nifty.config import about
from nifty.field import Field
from nifty.spaces import Space
from nifty.field_types import FieldType
import nifty.nifty_utilities as utilities


class LinearOperator(object):
    """Abstract base class for linear operators acting on nifty Fields.

    Concrete operators implement the abstract properties below and override
    whichever of the private hooks ``_times``, ``_adjoint_times``,
    ``_inverse_times``, ``_adjoint_inverse_times`` and
    ``_inverse_adjoint_times`` they support.  The public methods validate the
    input field against the operator's domain/target and field types,
    optionally apply ``Field.weight``, and then dispatch to those hooks.
    """

    __metaclass__ = abc.ABCMeta

    def __init__(self):
        pass

    def _parse_domain(self, domain):
        """Normalize `domain` into a tuple of `Space` instances.

        Parameters
        ----------
        domain : None, Space, or iterable of Space
            ``None`` becomes the empty tuple, a single Space is wrapped in a
            1-tuple, and any other non-tuple iterable is converted to a tuple.

        Returns
        -------
        tuple of Space

        Raises
        ------
        TypeError
            If any entry is not a `Space`.
        """
        if domain is None:
            domain = ()
        elif isinstance(domain, Space):
            domain = (domain,)
        elif not isinstance(domain, tuple):
            domain = tuple(domain)

        for d in domain:
            if not isinstance(d, Space):
                raise TypeError(about._errors.cstring(
                    "ERROR: Given object contains something that is not a "
                    "nifty.space."))
        return domain

    def _parse_field_type(self, field_type):
        """Normalize `field_type` into a tuple of `FieldType` instances.

        Parameters
        ----------
        field_type : None, FieldType, or iterable of FieldType
            ``None`` becomes the empty tuple, a single FieldType is wrapped in
            a 1-tuple, and any other non-tuple iterable is converted to a
            tuple.

        Returns
        -------
        tuple of FieldType

        Raises
        ------
        TypeError
            If any entry is not a `FieldType`.
        """
        if field_type is None:
            field_type = ()
        elif isinstance(field_type, FieldType):
            field_type = (field_type,)
        elif not isinstance(field_type, tuple):
            field_type = tuple(field_type)

        for ft in field_type:
            if not isinstance(ft, FieldType):
                raise TypeError(about._errors.cstring(
                    "ERROR: Given object is not a nifty.FieldType."))
        return field_type

    @abc.abstractproperty
    def domain(self):
        """Spaces expected by the input of `times`.

        An empty tuple disables the domain check when no `spaces` are given.
        """
        raise NotImplementedError

    @abc.abstractproperty
    def target(self):
        """Spaces expected by the input of `inverse_times`/`adjoint_times`.

        An empty tuple disables the check when no `spaces` are given.
        """
        raise NotImplementedError

    @abc.abstractproperty
    def field_type(self):
        """Field types expected by the input of `times`.

        An empty tuple disables the check when no `types` are given.
        """
        raise NotImplementedError

    @abc.abstractproperty
    def field_type_target(self):
        """Field types expected by `inverse_times`/`adjoint_times` inputs.

        An empty tuple disables the check when no `types` are given.
        """
        raise NotImplementedError

    @abc.abstractproperty
    def implemented(self):
        """Whether the hooks handle weighting themselves.

        If False, the public methods call ``Field.weight`` on the input
        (`times`, `adjoint_times`) or with ``power=-1`` on the output
        (the inverse-type methods).
        """
        raise NotImplementedError

    @abc.abstractproperty
    def unitary(self):
        """If True, adjoints are computed through `inverse_times`/`times`."""
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        """Alias for `times`."""
        return self.times(*args, **kwargs)

    def times(self, x, spaces=None, types=None):
        """Apply the operator to the Field `x`.

        `spaces`/`types` optionally select which entries of ``x.domain`` /
        ``x.field_type`` the operator's `domain`/`field_type` correspond to.
        """
        spaces, types = self._check_input_compatibility(x, spaces, types)

        if not self.implemented:
            x = x.weight(spaces=spaces)

        y = self._times(x, spaces, types)
        return y

    def inverse_times(self, x, spaces=None, types=None):
        """Apply the inverse operator to `x`, which must live on `target`."""
        spaces, types = self._check_input_compatibility(x, spaces, types,
                                                        inverse=True)

        y = self._inverse_times(x, spaces, types)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def adjoint_times(self, x, spaces=None, types=None):
        """Apply the adjoint operator to `x`, which must live on `target`.

        For unitary operators this delegates to `inverse_times`.
        """
        if self.unitary:
            return self.inverse_times(x, spaces, types)

        spaces, types = self._check_input_compatibility(x, spaces, types,
                                                        inverse=True)

        if not self.implemented:
            x = x.weight(spaces=spaces)
        y = self._adjoint_times(x, spaces, types)
        return y

    def adjoint_inverse_times(self, x, spaces=None, types=None):
        """Apply the inverse of the adjoint to `x` (input on `domain`).

        For unitary operators this delegates to `times`.
        """
        if self.unitary:
            return self.times(x, spaces, types)

        spaces, types = self._check_input_compatibility(x, spaces, types)

        y = self._adjoint_inverse_times(x, spaces, types)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def inverse_adjoint_times(self, x, spaces=None, types=None):
        """Apply the adjoint of the inverse to `x` (input on `domain`).

        For unitary operators this delegates to `times`.
        """
        if self.unitary:
            return self.times(x, spaces, types)

        spaces, types = self._check_input_compatibility(x, spaces, types)

        y = self._inverse_adjoint_times(x, spaces, types)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def _times(self, x, spaces, types):
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'times'."))

    def _adjoint_times(self, x, spaces, types):
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'adjoint_times'."))

    def _inverse_times(self, x, spaces, types):
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'inverse_times'."))

    def _adjoint_inverse_times(self, x, spaces, types):
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'adjoint_inverse_times'."))

    def _inverse_adjoint_times(self, x, spaces, types):
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'inverse_adjoint_times'."))

    def _check_input_compatibility(self, x, spaces, types, inverse=False):
        """Validate `x` against the operator's domain and field types.

        Parameters
        ----------
        x : Field
            The input field.
        spaces, types
            Optional selection of entries of ``x.domain`` / ``x.field_type``;
            sanitized via ``utilities.cast_axis_to_tuple``.
        inverse : bool
            If True, check against `target`/`field_type_target` instead of
            `domain`/`field_type`.

        Returns
        -------
        tuple
            The sanitized ``(spaces, types)``.

        Raises
        ------
        ValueError
            If `x` is not a Field or its domain/field types do not match.
        """
        if not isinstance(x, Field):
            raise ValueError(about._errors.cstring(
                "ERROR: supplied object is not a `nifty.Field`."))

        # sanitize the `spaces` and `types` input
        spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
        types = utilities.cast_axis_to_tuple(types, len(x.field_type))

        # if the operator's domain is set to something, there are two valid
        # cases:
        # 1. Case:
        #   The user specifies with `spaces` that the operator's domain should
        #   be applied to certain spaces in the domain-tuple of x.
        # 2. Case:
        #   The domains of self and x match completely.

        if not inverse:
            self_domain = self.domain
            self_field_type = self.field_type
        else:
            self_domain = self.target
            self_field_type = self.field_type_target

        if spaces is None:
            if self_domain != () and self_domain != x.domain:
                raise ValueError(about._errors.cstring(
                    "ERROR: The operator's and field's domains don't "
                    "match."))
        else:
            # Without this, a too-short `spaces` silently passes a partial
            # check and a too-long one fails with an opaque IndexError.
            if len(spaces) != len(self_domain):
                raise ValueError(about._errors.cstring(
                    "ERROR: The number of given spaces doesn't match the "
                    "number of spaces in the operator's domain."))
            for i, space_index in enumerate(spaces):
                if x.domain[space_index] != self_domain[i]:
                    raise ValueError(about._errors.cstring(
                        "ERROR: The operator's and field's domains don't "
                        "match."))

        if types is None:
            if self_field_type != () and self_field_type != x.field_type:
                raise ValueError(about._errors.cstring(
                    "ERROR: The operator's and field's field_types don't "
                    "match."))
        else:
            # Same reasoning as for `spaces` above.
            if len(types) != len(self_field_type):
                raise ValueError(about._errors.cstring(
                    "ERROR: The number of given types doesn't match the "
                    "number of the operator's field_types."))
            for i, field_type_index in enumerate(types):
                if x.field_type[field_type_index] != self_field_type[i]:
                    raise ValueError(about._errors.cstring(
                        "ERROR: The operator's and field's field_types "
                        "don't match."))

        return (spaces, types)

    def __repr__(self):
        return str(self.__class__)