linear_operator.py 7.16 KB
Newer Older
1
2
# -*- coding: utf-8 -*-

3
4
import abc

5
6
7
8
9
10
11
12
from nifty.config import about
from nifty.field import Field
from nifty.spaces import Space
from nifty.field_types import FieldType
import nifty.nifty_utilities as utilities


class LinearOperator(object):
    """Abstract base class for linear operators acting on nifty Fields.

    Concrete operators provide the abstract properties `target`,
    `field_type_target` and `unitary`, and override whichever of the
    private hooks `_times`, `_adjoint_times`, `_inverse_times`,
    `_adjoint_inverse_times` and `_inverse_adjoint_times` they support.
    The public methods (`times`, `adjoint_times`, ...) validate their
    input with `_check_input_compatibility` and then delegate to the
    corresponding hook.

    Parameters
    ----------
    domain : Space, tuple or list of Space, or None
        The spaces the operator acts on; normalized to a tuple (None -> ()).
    field_type : FieldType, tuple or list of FieldType, or None
        The field types the operator acts on; normalized like `domain`.
    implemented : bool
        If False, the public methods apply `Field.weight` around the hooks:
        `times` and `adjoint_times` weight the input before calling the
        hook, while the inverse variants weight the hook's result with
        `power=-1`. (Presumably this converts between weighted and
        unweighted representations -- confirm against `Field.weight`.)
    """

    # Python 2 style metaclass declaration. Python 3 ignores this
    # attribute, so the abstract members below are only enforced on Py2.
    __metaclass__ = abc.ABCMeta

    def __init__(self, domain=(), field_type=(), implemented=False):
        self._domain = self._parse_domain(domain)
        self._field_type = self._parse_field_type(field_type)
        self._implemented = bool(implemented)

    @property
    def domain(self):
        """tuple of Space: the spaces the operator acts on."""
        return self._domain

    @abc.abstractproperty
    def target(self):
        """The spaces the operator maps into; provided by subclasses."""
        raise NotImplementedError

    @property
    def field_type(self):
        """tuple of FieldType: the field types the operator acts on."""
        return self._field_type

    @abc.abstractproperty
    def field_type_target(self):
        """The field types the operator maps into; provided by subclasses."""
        raise NotImplementedError

    @staticmethod
    def _as_tuple(obj):
        """Normalize `obj` to a tuple.

        `None` becomes `()`, a tuple is returned unchanged, a list is
        converted to a tuple of its elements, and any other object is
        wrapped as a 1-tuple.
        """
        if obj is None:
            return ()
        if isinstance(obj, tuple):
            return obj
        if isinstance(obj, list):
            # Accept lists of spaces/field types instead of wrapping the
            # list itself (which would then fail the isinstance check).
            return tuple(obj)
        return (obj,)

    def _parse_domain(self, domain):
        """Return `domain` normalized to a tuple of `Space` instances.

        Raises
        ------
        TypeError
            If any entry is not a `nifty.Space`.
        """
        domain = self._as_tuple(domain)
        for d in domain:
            if not isinstance(d, Space):
                raise TypeError(about._errors.cstring(
                    "ERROR: Given object contains something that is not a "
                    "nifty.space."))
        return domain

    def _parse_field_type(self, field_type):
        """Return `field_type` normalized to a tuple of `FieldType` instances.

        Raises
        ------
        TypeError
            If any entry is not a `nifty.FieldType`.
        """
        field_type = self._as_tuple(field_type)
        for ft in field_type:
            if not isinstance(ft, FieldType):
                raise TypeError(about._errors.cstring(
                    "ERROR: Given object is not a nifty.FieldType."))
        return field_type

    @property
    def implemented(self):
        """bool: if False, the public methods apply `Field.weight` around
        the private hooks (see the class docstring)."""
        return self._implemented

    @abc.abstractproperty
    def unitary(self):
        """bool: whether the operator is unitary.

        When True, `adjoint_times` dispatches to `inverse_times`, and both
        `adjoint_inverse_times` and `inverse_adjoint_times` dispatch to
        `times`.
        """
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        """Alias for `times`."""
        return self.times(*args, **kwargs)

    def times(self, x, spaces=None, types=None):
        """Apply the operator to the Field `x`.

        Parameters
        ----------
        x : Field
            The input field.
        spaces : int, tuple of int, or None
            Indices into `x.domain` the operator acts on; see
            `_check_input_compatibility` for the accepted combinations.
        types : int, tuple of int, or None
            Indices into `x.field_type` the operator acts on.

        Returns
        -------
        Whatever the subclass hook `_times` returns.
        """
        spaces, types = self._check_input_compatibility(x, spaces, types)

        if not self.implemented:
            x = x.weight(spaces=spaces)

        return self._times(x, spaces, types)

    def inverse_times(self, x, spaces=None, types=None):
        """Apply the inverse operator to `x` (parameters as in `times`)."""
        spaces, types = self._check_input_compatibility(x, spaces, types)

        y = self._inverse_times(x, spaces, types)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def adjoint_times(self, x, spaces=None, types=None):
        """Apply the adjoint operator to `x` (parameters as in `times`).

        For unitary operators the adjoint equals the inverse, so this
        delegates to `inverse_times`.
        """
        if self.unitary:
            return self.inverse_times(x, spaces, types)

        spaces, types = self._check_input_compatibility(x, spaces, types)

        if not self.implemented:
            x = x.weight(spaces=spaces)
        return self._adjoint_times(x, spaces, types)

    def adjoint_inverse_times(self, x, spaces=None, types=None):
        """Apply the adjoint of the inverse to `x` (parameters as in `times`).

        For unitary operators this equals the operator itself, so this
        delegates to `times`.
        """
        if self.unitary:
            return self.times(x, spaces, types)

        spaces, types = self._check_input_compatibility(x, spaces, types)

        y = self._adjoint_inverse_times(x, spaces, types)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def inverse_adjoint_times(self, x, spaces=None, types=None):
        """Apply the inverse of the adjoint to `x` (parameters as in `times`).

        For unitary operators this equals the operator itself, so this
        delegates to `times`.
        """
        if self.unitary:
            return self.times(x, spaces, types)

        spaces, types = self._check_input_compatibility(x, spaces, types)

        y = self._inverse_adjoint_times(x, spaces, types)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    # Hooks for subclasses; the base versions always raise.

    def _times(self, x, spaces, types):
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'times'."))

    def _adjoint_times(self, x, spaces, types):
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'adjoint_times'."))

    def _inverse_times(self, x, spaces, types):
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'inverse_times'."))

    def _adjoint_inverse_times(self, x, spaces, types):
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'adjoint_inverse_times'."))

    def _inverse_adjoint_times(self, x, spaces, types):
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'inverse_adjoint_times'."))

    def _check_input_compatibility(self, x, spaces, types):
        """Validate `x`, `spaces` and `types` against the operator.

        `spaces` and `types` are normalized with
        `utilities.cast_axis_to_tuple` (against `len(x.domain)` and
        `len(x.field_type)` respectively) and returned as a pair.

        For the domain (and analogously for the field types) two cases
        are accepted:
        1. `spaces` is None: the operator's domain must be empty or equal
           to `x.domain` as a whole.
        2. `spaces` is given: the operator may have at most one domain
           space, `spaces` must have as many entries as the operator's
           domain, and for a single space `x.domain[spaces[0]]` must
           equal the operator's domain space.

        Raises
        ------
        ValueError
            If `x` is not a Field or any of the checks above fails.
        """
        # NOTE: TypeError would be more conventional for a non-Field
        # input; ValueError is kept because callers may rely on it.
        if not isinstance(x, Field):
            raise ValueError(about._errors.cstring(
                "ERROR: supplied object is not a `nifty.Field`."))

        # sanitize the `spaces` and `types` input
        spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
        types = utilities.cast_axis_to_tuple(types, len(x.field_type))

        if spaces is None:
            if self.domain != () and self.domain != x.domain:
                raise ValueError(about._errors.cstring(
                    "ERROR: The operator's and field's domains don't "
                    "match."))
        else:
            if len(self.domain) > 1:
                raise ValueError(about._errors.cstring(
                    "ERROR: Specifying `spaces` for operators with multiple "
                    "domain spaces is not valid."))
            elif len(spaces) != len(self.domain):
                raise ValueError(about._errors.cstring(
                    "ERROR: Length of `spaces` does not match the number of "
                    "spaces in the operator's domain."))
            elif len(spaces) == 1:
                if x.domain[spaces[0]] != self.domain[0]:
                    raise ValueError(about._errors.cstring(
                        "ERROR: The operator's and field's domains don't "
                        "match."))

        if types is None:
            if self.field_type != () and self.field_type != x.field_type:
                raise ValueError(about._errors.cstring(
                    "ERROR: The operator's and field's field_types don't "
                    "match."))
        else:
            if len(self.field_type) > 1:
                raise ValueError(about._errors.cstring(
                    "ERROR: Specifying `types` for operators with multiple "
                    "field-types is not valid."))
            elif len(types) != len(self.field_type):
                raise ValueError(about._errors.cstring(
                    "ERROR: Length of `types` does not match the number of "
                    "field-types of the operator."))
            elif len(types) == 1:
                if x.field_type[types[0]] != self.field_type[0]:
                    raise ValueError(about._errors.cstring(
                        "ERROR: The operator's and field's field_types "
                        "don't match."))
        return (spaces, types)

    def __repr__(self):
        return str(self.__class__)