linear_operator.py 7.27 KB
Newer Older
1
2
# -*- coding: utf-8 -*-

3
4
import abc

5
6
7
8
9
10
11
12
from nifty.config import about
from nifty.field import Field
from nifty.spaces import Space
from nifty.field_types import FieldType
import nifty.nifty_utilities as utilities


class LinearOperator(abc.ABCMeta('_LinearOperatorBase', (object,), {})):
    """Abstract base class for linear operators acting on `Field` objects.

    The class stores the operator's domain (a tuple of `Space` instances)
    and field type (a tuple of `FieldType` instances), validates the input
    of every application against them, and dispatches to the private
    worker methods `_times`, `_inverse_times`, `_adjoint_times`,
    `_adjoint_inverse_times` and `_inverse_adjoint_times`, which concrete
    subclasses are expected to override.

    Subclasses must provide the properties `target`, `field_type_target`,
    `implemented` and `unitary`.

    Notes
    -----
    The base class is created by calling `abc.ABCMeta` directly instead of
    using the `__metaclass__` attribute. The attribute form is silently
    ignored by Python 3, which would disable the abstract-property
    enforcement; calling the metaclass works on Python 2 and Python 3.
    """

    def __init__(self, domain=(), field_type=()):
        """Store the sanitized `domain` and `field_type` tuples.

        Parameters
        ----------
        domain : None, Space or iterable of Space
            See `_parse_domain`.
        field_type : None, FieldType or iterable of FieldType
            See `_parse_field_type`.
        """
        self._domain = self._parse_domain(domain)
        self._field_type = self._parse_field_type(field_type)

    def _parse_domain(self, domain):
        """Normalize `domain` into a tuple of `Space` instances.

        Parameters
        ----------
        domain : None, Space or iterable of Space
            `None` becomes the empty tuple, a single `Space` is wrapped
            into a one-element tuple, any other non-tuple is passed to
            `tuple()`.

        Returns
        -------
        tuple of Space

        Raises
        ------
        TypeError
            If any element of the resulting tuple is not a `Space`.
        """
        if domain is None:
            domain = ()
        elif isinstance(domain, Space):
            domain = (domain,)
        elif not isinstance(domain, tuple):
            domain = tuple(domain)

        for d in domain:
            if not isinstance(d, Space):
                raise TypeError(about._errors.cstring(
                    "ERROR: Given object contains something that is not a "
                    "nifty.space."))
        return domain

    def _parse_field_type(self, field_type):
        """Normalize `field_type` into a tuple of `FieldType` instances.

        Parameters
        ----------
        field_type : None, FieldType or iterable of FieldType
            `None` becomes the empty tuple, a single `FieldType` is
            wrapped into a one-element tuple, any other non-tuple is
            passed to `tuple()`.

        Returns
        -------
        tuple of FieldType

        Raises
        ------
        TypeError
            If any element of the resulting tuple is not a `FieldType`.
        """
        if field_type is None:
            field_type = ()
        elif isinstance(field_type, FieldType):
            field_type = (field_type,)
        elif not isinstance(field_type, tuple):
            field_type = tuple(field_type)

        for ft in field_type:
            if not isinstance(ft, FieldType):
                raise TypeError(about._errors.cstring(
                    "ERROR: Given object is not a nifty.FieldType."))
        return field_type

    @property
    def domain(self):
        """tuple of Space : the domain tuple stored at construction."""
        return self._domain

    # NOTE: `abc.abstractproperty` is kept on purpose. Stacking `property`
    # on top of `abc.abstractmethod` does not mark the property as abstract
    # on Python 2.
    @abc.abstractproperty
    def target(self):
        """Abstract property; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def field_type(self):
        """tuple of FieldType : the field types stored at construction."""
        return self._field_type

    @abc.abstractproperty
    def field_type_target(self):
        """Abstract property; must be provided by subclasses."""
        raise NotImplementedError

    @abc.abstractproperty
    def implemented(self):
        """Abstract property; must be provided by subclasses.

        When it evaluates to false, the public application methods call
        `weight` on their input or output (see the individual methods).
        """
        raise NotImplementedError

    @abc.abstractproperty
    def unitary(self):
        """Abstract property; must be provided by subclasses.

        When it evaluates to true, the adjoint-type methods short-circuit
        to `inverse_times` or `times`.
        """
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        """Apply the operator; shorthand for `times`."""
        return self.times(*args, **kwargs)

    def times(self, x, spaces=None, types=None):
        """Apply the operator to `x`.

        The input is validated with `_check_input_compatibility`. If the
        operator is not `implemented`, `x.weight(spaces=spaces)` is
        applied before `_times` is called.

        Parameters
        ----------
        x : Field
        spaces : optional
            Selection of entries of `x.domain` the operator acts on.
        types : optional
            Selection of entries of `x.field_type` the operator acts on.

        Returns
        -------
        The result of `_times`.
        """
        spaces, types = self._check_input_compatibility(x, spaces, types)

        if not self.implemented:
            x = x.weight(spaces=spaces)

        y = self._times(x, spaces, types)
        return y

    def inverse_times(self, x, spaces=None, types=None):
        """Apply the inverse operator to `x`.

        The input is validated with `_check_input_compatibility`. If the
        operator is not `implemented`, the result of `_inverse_times` is
        passed through `weight(power=-1, spaces=spaces)`.

        Parameters are the same as for `times`.
        """
        spaces, types = self._check_input_compatibility(x, spaces, types)

        y = self._inverse_times(x, spaces, types)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def adjoint_times(self, x, spaces=None, types=None):
        """Apply the adjoint operator to `x`.

        For a `unitary` operator this delegates to `inverse_times`.
        Otherwise the input is validated, weighted with
        `x.weight(spaces=spaces)` if the operator is not `implemented`,
        and handed to `_adjoint_times`.

        Parameters are the same as for `times`.
        """
        if self.unitary:
            return self.inverse_times(x, spaces, types)

        spaces, types = self._check_input_compatibility(x, spaces, types)

        if not self.implemented:
            x = x.weight(spaces=spaces)
        y = self._adjoint_times(x, spaces, types)
        return y

    def adjoint_inverse_times(self, x, spaces=None, types=None):
        """Apply the adjoint of the inverse operator to `x`.

        For a `unitary` operator this delegates to `times`. Otherwise the
        input is validated and handed to `_adjoint_inverse_times`; if the
        operator is not `implemented`, the result is passed through
        `weight(power=-1, spaces=spaces)`.

        Parameters are the same as for `times`.
        """
        if self.unitary:
            return self.times(x, spaces, types)

        spaces, types = self._check_input_compatibility(x, spaces, types)

        y = self._adjoint_inverse_times(x, spaces, types)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def inverse_adjoint_times(self, x, spaces=None, types=None):
        """Apply the inverse of the adjoint operator to `x`.

        For a `unitary` operator this delegates to `times`. Otherwise the
        input is validated and handed to `_inverse_adjoint_times`; if the
        operator is not `implemented`, the result is passed through
        `weight(power=-1, spaces=spaces)`.

        Parameters are the same as for `times`.
        """
        if self.unitary:
            return self.times(x, spaces, types)

        spaces, types = self._check_input_compatibility(x, spaces, types)

        y = self._inverse_adjoint_times(x, spaces, types)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    # ---- worker methods, to be overridden by concrete operators ----

    def _times(self, x, spaces, types):
        """Worker for `times`; raises NotImplementedError by default."""
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'times'."))

    def _adjoint_times(self, x, spaces, types):
        """Worker for `adjoint_times`; raises NotImplementedError."""
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'adjoint_times'."))

    def _inverse_times(self, x, spaces, types):
        """Worker for `inverse_times`; raises NotImplementedError."""
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'inverse_times'."))

    def _adjoint_inverse_times(self, x, spaces, types):
        """Worker for `adjoint_inverse_times`; raises NotImplementedError."""
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'adjoint_inverse_times'."))

    def _inverse_adjoint_times(self, x, spaces, types):
        """Worker for `inverse_adjoint_times`; raises NotImplementedError."""
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'inverse_adjoint_times'."))

    def _check_input_compatibility(self, x, spaces, types):
        """Validate `x`, `spaces` and `types` against this operator.

        Parameters
        ----------
        x : Field
            The object the operator is about to be applied to.
        spaces
            Raw `spaces` argument as supplied by the caller.
        types
            Raw `types` argument as supplied by the caller.

        Returns
        -------
        tuple
            `(spaces, types)` as returned by
            `utilities.cast_axis_to_tuple`.

        Raises
        ------
        ValueError
            If `x` is not a `Field`, or if the operator's domain or field
            type is incompatible with `x` and the given selection.
        """
        if not isinstance(x, Field):
            raise ValueError(about._errors.cstring(
                "ERROR: supplied object is not a `nifty.Field`."))

        # sanitize the `spaces` and `types` input
        # NOTE(review): the `is None` checks below presume that
        # `cast_axis_to_tuple` passes `None` through unchanged -- confirm.
        spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
        types = utilities.cast_axis_to_tuple(types, len(x.field_type))

        # if the operator's domain is set to something, there are two valid
        # cases:
        # 1. Case:
        #   The user specifies with `spaces` that the operator's domain
        #   should be applied to a certain domain in the domain-tuple of x.
        #   This is only valid if len(self.domain)==1.
        # 2. Case:
        #   The domains of self and x match completely.

        if spaces is None:
            if self.domain != () and self.domain != x.domain:
                raise ValueError(about._errors.cstring(
                    "ERROR: The operator's and field's domains don't "
                    "match."))
        else:
            if len(self.domain) > 1:
                raise ValueError(about._errors.cstring(
                    "ERROR: Specifying `spaces` for operators with multiple "
                    "domain spaces is not valid."))
            elif len(spaces) != len(self.domain):
                raise ValueError(about._errors.cstring(
                    "ERROR: Length of `spaces` does not match the number of "
                    "spaces in the operator's domain."))
            elif len(spaces) == 1:
                if x.domain[spaces[0]] != self.domain[0]:
                    raise ValueError(about._errors.cstring(
                        "ERROR: The operator's and field's domains don't "
                        "match."))

        # the same two cases apply to the field types
        if types is None:
            if self.field_type != () and self.field_type != x.field_type:
                raise ValueError(about._errors.cstring(
                    "ERROR: The operator's and field's field_types don't "
                    "match."))
        else:
            if len(self.field_type) > 1:
                raise ValueError(about._errors.cstring(
                    "ERROR: Specifying `types` for operators with multiple "
                    "field-types is not valid."))
            elif len(types) != len(self.field_type):
                raise ValueError(about._errors.cstring(
                    "ERROR: Length of `types` does not match the number of "
                    "the operator's field-types."))
            elif len(types) == 1:
                if x.field_type[types[0]] != self.field_type[0]:
                    raise ValueError(about._errors.cstring(
                        "ERROR: The operator's and field's field_type "
                        "don't match."))
        return (spaces, types)

    def __repr__(self):
        """Return the string form of the instance's class."""
        return str(self.__class__)