# -*- coding: utf-8 -*-

import abc

from nifty.config import about
from nifty.field import Field
from nifty.spaces import Space
from nifty.field_types import FieldType
import nifty.nifty_utilities as utilities


class LinearOperator(object):
    """Abstract base class for linear operators acting on nifty Fields.

    The operator's input side is described by `domain` (a tuple of nifty
    Spaces) and `field_type` (a tuple of nifty FieldTypes). The inverse and
    adjoint methods instead check their input against the abstract
    properties `target` and `field_type_target`.

    Subclasses provide the abstract properties and override the private
    hooks `_times`, `_adjoint_times`, `_inverse_times`,
    `_adjoint_inverse_times` and `_inverse_adjoint_times`. The public
    wrappers validate the input Field. When `implemented` is False, `times`
    and `adjoint_times` apply `Field.weight` to their input, and the three
    inverse variants apply `Field.weight(power=-1)` to their output.
    """

    # Python 2 metaclass declaration (has no effect under Python 3).
    __metaclass__ = abc.ABCMeta

    def __init__(self, domain=(), field_type=()):
        # Both arguments are normalized to validated tuples.
        self._domain = self._parse_domain(domain)
        self._field_type = self._parse_field_type(field_type)

    def _parse_domain(self, domain):
        """Normalize `domain` to a tuple of nifty Spaces.

        Accepts None (giving an empty tuple), a single Space, or any
        iterable of Spaces.

        Raises
        ------
        TypeError
            If any element is not a nifty Space.
        """
        if domain is None:
            domain = ()
        elif isinstance(domain, Space):
            domain = (domain,)
        elif not isinstance(domain, tuple):
            domain = tuple(domain)

        for d in domain:
            if not isinstance(d, Space):
                raise TypeError(about._errors.cstring(
                    "ERROR: Given object contains something that is not a "
                    "nifty.space."))
        return domain

    def _parse_field_type(self, field_type):
        """Normalize `field_type` to a tuple of nifty FieldTypes.

        Accepts None (giving an empty tuple), a single FieldType, or any
        iterable of FieldTypes.

        Raises
        ------
        TypeError
            If any element is not a nifty FieldType.
        """
        if field_type is None:
            field_type = ()
        elif isinstance(field_type, FieldType):
            field_type = (field_type,)
        elif not isinstance(field_type, tuple):
            field_type = tuple(field_type)

        for ft in field_type:
            if not isinstance(ft, FieldType):
                raise TypeError(about._errors.cstring(
                    "ERROR: Given object is not a nifty.FieldType."))
        return field_type

    @property
    def domain(self):
        """Tuple of Spaces that `times` checks its input against."""
        return self._domain

    @abc.abstractproperty
    def target(self):
        """Tuple of Spaces that the inverse/adjoint methods check against."""
        raise NotImplementedError

    @property
    def field_type(self):
        """Tuple of FieldTypes that `times` checks its input against."""
        return self._field_type

    @abc.abstractproperty
    def field_type_target(self):
        """Tuple of FieldTypes that inverse/adjoint methods check against."""
        raise NotImplementedError

    @abc.abstractproperty
    def implemented(self):
        """If False, the public methods weight Fields via `Field.weight`."""
        raise NotImplementedError

    @abc.abstractproperty
    def unitary(self):
        """If True, adjoint calls are delegated to `inverse_times`/`times`."""
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        """Alias for `times`."""
        return self.times(*args, **kwargs)

    def times(self, x, spaces=None, types=None):
        """Apply the operator to Field `x`.

        `spaces`/`types` select the entries of `x.domain`/`x.field_type`
        the operator acts on (None: the full domain/field_type).
        """
        spaces, types = self._check_input_compatibility(x, spaces, types)

        if not self.implemented:
            x = x.weight(spaces=spaces)

        return self._times(x, spaces, types)

    def inverse_times(self, x, spaces=None, types=None):
        """Apply the inverse operator to Field `x` (given on the target)."""
        spaces, types = self._check_input_compatibility(x, spaces, types,
                                                        inverse=True)

        y = self._inverse_times(x, spaces, types)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def adjoint_times(self, x, spaces=None, types=None):
        """Apply the adjoint operator to Field `x` (given on the target).

        For unitary operators this delegates to `inverse_times`.
        """
        if self.unitary:
            return self.inverse_times(x, spaces, types)

        spaces, types = self._check_input_compatibility(x, spaces, types,
                                                        inverse=True)

        if not self.implemented:
            x = x.weight(spaces=spaces)
        return self._adjoint_times(x, spaces, types)

    def adjoint_inverse_times(self, x, spaces=None, types=None):
        """Apply the inverse of the adjoint operator to Field `x`.

        For unitary operators this delegates to `times`.
        """
        if self.unitary:
            return self.times(x, spaces, types)

        spaces, types = self._check_input_compatibility(x, spaces, types)

        y = self._adjoint_inverse_times(x, spaces, types)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def inverse_adjoint_times(self, x, spaces=None, types=None):
        """Apply the adjoint of the inverse operator to Field `x`.

        For unitary operators this delegates to `times`.
        """
        if self.unitary:
            return self.times(x, spaces, types)

        spaces, types = self._check_input_compatibility(x, spaces, types)

        y = self._inverse_adjoint_times(x, spaces, types)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def _times(self, x, spaces, types):
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'times'."))

    def _adjoint_times(self, x, spaces, types):
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'adjoint_times'."))

    def _inverse_times(self, x, spaces, types):
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'inverse_times'."))

    def _adjoint_inverse_times(self, x, spaces, types):
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'adjoint_inverse_times'."))

    def _inverse_adjoint_times(self, x, spaces, types):
        raise NotImplementedError(about._errors.cstring(
            "ERROR: no generic instance method 'inverse_adjoint_times'."))

    @staticmethod
    def _check_parts_match(operator_parts, field_parts, indices, name):
        """Check that the operator's `operator_parts` fit `field_parts`.

        Two cases are valid:
        1. `indices` is None. Either the operator's parts are unset (an
           empty tuple), or they equal the field's parts exactly.
        2. `indices` selects entries of `field_parts`. There must be exactly
           one index per operator part, and each selected field part must
           equal the corresponding operator part.

        Raises
        ------
        ValueError
            If neither case holds. `name` is used in the error message.
        """
        mismatch_msg = ("ERROR: The operator's and field's %s don't match."
                        % name)

        if indices is None:
            if operator_parts != () and operator_parts != field_parts:
                raise ValueError(about._errors.cstring(mismatch_msg))
            return

        # Without this check, a too-short `indices` would compare only a
        # prefix of the operator's parts and silently accept a mismatch.
        if len(indices) != len(operator_parts):
            raise ValueError(about._errors.cstring(
                "ERROR: The number of given indices (%d) doesn't match the "
                "number of the operator's %s (%d)." %
                (len(indices), name, len(operator_parts))))

        for operator_part, index in zip(operator_parts, indices):
            if field_parts[index] != operator_part:
                raise ValueError(about._errors.cstring(mismatch_msg))

    def _check_input_compatibility(self, x, spaces, types, inverse=False):
        """Validate that Field `x` can be fed to this operator.

        Parameters
        ----------
        x : Field
            The input field.
        spaces, types
            Indices into `x.domain` / `x.field_type` that the operator acts
            on, or None for the full domain / field_type. They are sanitized
            by `utilities.cast_axis_to_tuple` (presumably an int or a
            sequence of ints).
        inverse : bool
            If True, check against `target` / `field_type_target` instead of
            `domain` / `field_type` (used by the inverse and adjoint calls).

        Returns
        -------
        tuple
            The sanitized `(spaces, types)`.

        Raises
        ------
        ValueError
            If `x` is not a Field, or its domain / field_type does not match
            the operator's.
        """
        if not isinstance(x, Field):
            raise ValueError(about._errors.cstring(
                "ERROR: supplied object is not a `nifty.Field`."))

        # sanitize the `spaces` and `types` input
        spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
        types = utilities.cast_axis_to_tuple(types, len(x.field_type))

        if inverse:
            self_domain = self.target
            self_field_type = self.field_type_target
        else:
            self_domain = self.domain
            self_field_type = self.field_type

        self._check_parts_match(self_domain, x.domain, spaces, "domains")
        self._check_parts_match(self_field_type, x.field_type, types,
                                "field_types")

        return (spaces, types)

    def __repr__(self):
        return str(self.__class__)