linear_operator.py 5.57 KB
Newer Older
1
2
# -*- coding: utf-8 -*-

3
4
import abc

5
from keepers import Loggable
6
7
8
9
from nifty.field import Field
import nifty.nifty_utilities as utilities


theos's avatar
theos committed
10
class LinearOperator(Loggable, object):
    """Abstract base class for linear operators acting on `nifty.Field`s.

    Subclasses provide the abstract properties (`domain`, `target`,
    `field_type`, `field_type_target`, `implemented`, `unitary`). They also
    override whichever private hooks (`_times`, `_adjoint_times`,
    `_inverse_times`, `_adjoint_inverse_times`, `_inverse_adjoint_times`)
    they support. The public methods do three things before or after calling
    the matching hook:

    * validate the input field against the operator;
    * apply the `unitary` shortcuts;
    * apply the `implemented` weighting.
    """

    __metaclass__ = abc.ABCMeta

    def __init__(self):
        pass

    def _parse_domain(self, domain):
        """Normalize `domain` via `utilities.parse_domain`."""
        return utilities.parse_domain(domain)

    def _parse_field_type(self, field_type):
        """Normalize `field_type` via `utilities.parse_field_type`."""
        return utilities.parse_field_type(field_type)

    @abc.abstractproperty
    def domain(self):
        """Spaces of the operator's input.

        This is checked against `x.domain` in the forward directions. When no
        `spaces` are selected, an empty tuple skips the whole-domain check.
        """
        raise NotImplementedError

    @abc.abstractproperty
    def target(self):
        """Spaces of the operator's output.

        Inputs of the inverse/adjoint directions are checked against it.
        """
        raise NotImplementedError

    @abc.abstractproperty
    def field_type(self):
        """Field types of the operator's input (checked like `domain`)."""
        raise NotImplementedError

    @abc.abstractproperty
    def field_type_target(self):
        """Field types of the operator's output (checked like `target`)."""
        raise NotImplementedError

    @abc.abstractproperty
    def implemented(self):
        """Flag that controls extra weighting around the hooks.

        When False, the public methods pass inputs through `Field.weight` in
        the forward/adjoint directions. They pass outputs through
        `Field.weight(power=-1)` in the inverse directions.
        """
        raise NotImplementedError

    @abc.abstractproperty
    def unitary(self):
        """Flag that enables the unitary shortcuts.

        When True, `adjoint_times` delegates to `inverse_times`.
        `adjoint_inverse_times` and `inverse_adjoint_times` delegate to
        `times`.
        """
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        """Alias for `times`."""
        return self.times(*args, **kwargs)

    def times(self, x, spaces=None, types=None, **kwargs):
        """Apply the operator to the field `x`.

        Parameters
        ----------
        x : Field
            Input field, validated against `domain` and `field_type`.
        spaces, types : optional
            Indices into `x.domain` / `x.field_type` that the operator acts
            on. If None, the operator's full domain / field_type must match
            the field's (unless the operator's tuple is empty).
        **kwargs
            Forwarded to `_times`.

        Returns
        -------
        The result of `_times`.

        Raises
        ------
        ValueError
            If `x` is not a Field or is incompatible with the operator.
        """
        spaces, types = self._check_input_compatibility(x, spaces, types)

        if not self.implemented:
            # NOTE(review): presumably applies volume factors; see Field.weight
            x = x.weight(spaces=spaces)

        return self._times(x, spaces, types, **kwargs)

    def inverse_times(self, x, spaces=None, types=None, **kwargs):
        """Apply the inverse operator to `x`.

        `x` must match `target` / `field_type_target`. Other arguments and
        exceptions are as in `times`. `**kwargs` are forwarded to
        `_inverse_times`.
        """
        spaces, types = self._check_input_compatibility(x, spaces, types,
                                                        inverse=True)

        y = self._inverse_times(x, spaces, types, **kwargs)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def adjoint_times(self, x, spaces=None, types=None, **kwargs):
        """Apply the adjoint operator to `x`.

        `x` must match `target` / `field_type_target`. For unitary operators
        this delegates to `inverse_times`. `**kwargs` are forwarded to
        whichever method handles the call.
        """
        if self.unitary:
            return self.inverse_times(x, spaces, types, **kwargs)

        spaces, types = self._check_input_compatibility(x, spaces, types,
                                                        inverse=True)

        if not self.implemented:
            x = x.weight(spaces=spaces)
        return self._adjoint_times(x, spaces, types, **kwargs)

    def adjoint_inverse_times(self, x, spaces=None, types=None, **kwargs):
        """Apply the inverse of the adjoint operator to `x`.

        `x` must match `domain` / `field_type`. For unitary operators this
        delegates to `times`. `**kwargs` are forwarded to whichever method
        handles the call.
        """
        if self.unitary:
            return self.times(x, spaces, types, **kwargs)

        spaces, types = self._check_input_compatibility(x, spaces, types)

        y = self._adjoint_inverse_times(x, spaces, types, **kwargs)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def inverse_adjoint_times(self, x, spaces=None, types=None, **kwargs):
        """Apply the adjoint of the inverse operator to `x`.

        `x` must match `domain` / `field_type`. For unitary operators this
        delegates to `times`. `**kwargs` are forwarded to whichever method
        handles the call.
        """
        if self.unitary:
            return self.times(x, spaces, types, **kwargs)

        spaces, types = self._check_input_compatibility(x, spaces, types)

        y = self._inverse_adjoint_times(x, spaces, types, **kwargs)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    # Default hooks. They accept **kwargs so that calling an unsupported
    # direction reports NotImplementedError rather than a TypeError about
    # unexpected keyword arguments.

    def _times(self, x, spaces, types, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'times'.")

    def _adjoint_times(self, x, spaces, types, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'adjoint_times'.")

    def _inverse_times(self, x, spaces, types, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'inverse_times'.")

    def _adjoint_inverse_times(self, x, spaces, types, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'adjoint_inverse_times'.")

    def _inverse_adjoint_times(self, x, spaces, types, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'inverse_adjoint_times'.")

    def _check_input_compatibility(self, x, spaces, types, inverse=False):
        """Validate `x` against the operator and normalize `spaces`/`types`.

        Parameters
        ----------
        x : Field
            Field to validate.
        spaces, types :
            User selection, normalized via `utilities.cast_axis_to_tuple`.
        inverse : bool
            If True, check against `target` / `field_type_target` instead of
            `domain` / `field_type`.

        Returns
        -------
        tuple
            The normalized `(spaces, types)`.

        Raises
        ------
        ValueError
            If `x` is not a Field, if the selection length differs from the
            operator's, or if any selected entry does not match.
        """
        if not isinstance(x, Field):
            raise ValueError(
                "supplied object is not a `nifty.Field`.")

        # sanitize the `spaces` and `types` input
        spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
        types = utilities.cast_axis_to_tuple(types, len(x.field_type))

        # If the operator's domain is set to something, there are two valid
        # cases:
        # 1. The user specifies with `spaces` the positions in x's
        #    domain-tuple that the operator's domain should be applied to.
        # 2. The domains of self and x match completely.
        if not inverse:
            self_domain = self.domain
            self_field_type = self.field_type
        else:
            self_domain = self.target
            self_field_type = self.field_type_target

        self._check_selection(x.domain, self_domain, spaces,
                              "domain", "spaces")
        self._check_selection(x.field_type, self_field_type, types,
                              "field_type", "types")

        return (spaces, types)

    @staticmethod
    def _check_selection(x_entries, self_entries, selection, kind, selector):
        """Check the entries of a field against those of the operator.

        `x_entries` (optionally indexed by `selection`) is compared with
        `self_entries`. `kind` ("domain"/"field_type") and `selector`
        ("spaces"/"types") are used only in error messages. Raises
        ValueError on any mismatch.
        """
        mismatch = "The operator's and field's %ss don't match." % kind

        if selection is None:
            # With no selection, an empty operator tuple skips the check.
            if self_entries != () and self_entries != x_entries:
                raise ValueError(mismatch)
            return

        # Without this check, a too-short selection would silently compare
        # only a prefix of the operator's entries. A too-long one would
        # fail with a bare IndexError.
        if len(selection) != len(self_entries):
            raise ValueError(
                "The operator has %i %s entries, but `%s` selects %i."
                % (len(self_entries), kind, selector, len(selection)))

        for index, expected in zip(selection, self_entries):
            if x_entries[index] != expected:
                raise ValueError(mismatch)

    def __repr__(self):
        return str(self.__class__)