linear_operator.py 4.25 KB
Newer Older
1
2
# -*- coding: utf-8 -*-

3
4
import abc

5
from keepers import Loggable
6
7
8
9
from nifty.field import Field
import nifty.nifty_utilities as utilities


theos's avatar
theos committed
10
class LinearOperator(Loggable, object):
    """Abstract base class for linear operators acting on `nifty.Field`s.

    Subclasses provide the abstract properties `domain`, `target`,
    `implemented` and `unitary`, and override the private hooks
    (`_times`, `_adjoint_times`, `_inverse_times`,
    `_adjoint_inverse_times`, `_inverse_adjoint_times`) for the
    operations they support. The public methods validate the input field,
    call `Field.weight` around the hook when `implemented` is False, and
    dispatch to the hook; hooks that are not overridden raise
    NotImplementedError.
    """

    __metaclass__ = abc.ABCMeta

    def __init__(self):
        pass

    def _parse_domain(self, domain):
        """Return `domain` normalized by `nifty_utilities.parse_domain`."""
        return utilities.parse_domain(domain)

    @abc.abstractproperty
    def domain(self):
        """Domain the input of `times`, `adjoint_inverse_times` and
        `inverse_adjoint_times` is checked against.

        Compared with the whole `x.domain`, or entry-by-entry with the
        spaces of `x.domain` selected by `spaces`.
        """
        raise NotImplementedError

    @abc.abstractproperty
    def target(self):
        """Domain the input of `inverse_times` and `adjoint_times` is
        checked against (same comparison rules as `domain`)."""
        raise NotImplementedError

    @abc.abstractproperty
    def implemented(self):
        """If False, the public methods additionally call `Field.weight`.

        `times` and `adjoint_times` weight the input before calling the
        hook; `inverse_times`, `adjoint_inverse_times` and
        `inverse_adjoint_times` weight the hook's result with power -1.
        """
        raise NotImplementedError

    @abc.abstractproperty
    def unitary(self):
        """If True, adjoint applications are delegated: `adjoint_times`
        calls `inverse_times`, while `adjoint_inverse_times` and
        `inverse_adjoint_times` call `times`."""
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        """Alias for `times`."""
        return self.times(*args, **kwargs)

    def times(self, x, spaces=None, **kwargs):
        """Apply the operator to the field `x`.

        Parameters
        ----------
        x : Field
            Input field; must match `self.domain` (entirely, or on the
            spaces selected by `spaces`).
        spaces : optional
            Indices into `x.domain` the operator acts on, in any form
            accepted by `nifty_utilities.cast_axis_to_tuple` (presumably an
            int or a tuple of ints). None requires a full domain match.
        **kwargs
            Forwarded to `_times`.

        Returns
        -------
        The result of `_times`.

        Raises
        ------
        ValueError
            If `x` is not a Field or its domain does not match.
        """
        spaces = self._check_input_compatibility(x, spaces)

        if not self.implemented:
            x = x.weight(spaces=spaces)

        y = self._times(x, spaces, **kwargs)
        return y

    def inverse_times(self, x, spaces=None, **kwargs):
        """Apply the inverse operator to `x`, whose domain must match
        `self.target`. Parameters and errors as in `times`; `**kwargs` are
        forwarded to `_inverse_times`."""
        spaces = self._check_input_compatibility(x, spaces, inverse=True)

        y = self._inverse_times(x, spaces, **kwargs)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def adjoint_times(self, x, spaces=None, **kwargs):
        """Apply the adjoint operator to `x`, whose domain must match
        `self.target`. For unitary operators this delegates to
        `inverse_times`. `**kwargs` are forwarded in both cases."""
        if self.unitary:
            # Forward kwargs so unitary operators behave like the
            # non-unitary path instead of silently dropping them.
            return self.inverse_times(x, spaces, **kwargs)

        spaces = self._check_input_compatibility(x, spaces, inverse=True)

        if not self.implemented:
            x = x.weight(spaces=spaces)
        y = self._adjoint_times(x, spaces, **kwargs)
        return y

    def adjoint_inverse_times(self, x, spaces=None, **kwargs):
        """Apply the inverse of the adjoint operator to `x`, whose domain
        must match `self.domain`. For unitary operators this delegates to
        `times`. `**kwargs` are forwarded in both cases."""
        if self.unitary:
            return self.times(x, spaces, **kwargs)

        spaces = self._check_input_compatibility(x, spaces)

        y = self._adjoint_inverse_times(x, spaces, **kwargs)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def inverse_adjoint_times(self, x, spaces=None, **kwargs):
        """Apply the adjoint of the inverse operator to `x`, whose domain
        must match `self.domain`. For unitary operators this delegates to
        `times`. `**kwargs` are forwarded in both cases."""
        if self.unitary:
            return self.times(x, spaces, **kwargs)

        spaces = self._check_input_compatibility(x, spaces)

        # Forward kwargs like every other public method does.
        y = self._inverse_adjoint_times(x, spaces, **kwargs)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    # The default hooks accept **kwargs so that calling an unsupported
    # operation with keyword arguments raises NotImplementedError rather
    # than a TypeError about unexpected arguments.

    def _times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'times'.")

    def _adjoint_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'adjoint_times'.")

    def _inverse_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'inverse_times'.")

    def _adjoint_inverse_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'adjoint_inverse_times'.")

    def _inverse_adjoint_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'inverse_adjoint_times'.")

    def _check_input_compatibility(self, x, spaces, inverse=False):
        """Validate `x` against the operator and return sanitized `spaces`.

        Parameters
        ----------
        x : Field
            The input field.
        spaces : optional
            Passed through `nifty_utilities.cast_axis_to_tuple`.
        inverse : bool
            If True, check against `self.target`, otherwise `self.domain`.

        Returns
        -------
        The value returned by `cast_axis_to_tuple` (None means the whole
        domain was checked).

        Raises
        ------
        ValueError
            If `x` is not a Field, the number of selected spaces differs
            from the operator's domain length, or any space mismatches.
        """
        if not isinstance(x, Field):
            raise ValueError(
                "supplied object is not a `nifty.Field`.")

        # sanitize the `spaces` input
        spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))

        # Inverse/adjoint applications take their input from the target.
        if inverse:
            self_domain = self.target
        else:
            self_domain = self.domain

        # There are two valid cases:
        # 1. Case:
        #   The user specifies with `spaces` that the operator's domain
        #   should be applied to certain spaces in the domain-tuple of x.
        # 2. Case:
        #   The domains of self and x match completely.
        if spaces is None:
            if self_domain != x.domain:
                raise ValueError(
                    "The operator's and field's domains don't "
                    "match.")
        else:
            # Without this check a too-short `spaces` would only validate a
            # prefix of the operator's domain, and a too-long one would
            # fail with an IndexError instead of a ValueError.
            if len(spaces) != len(self_domain):
                raise ValueError(
                    "The number of selected spaces (%i) does not match "
                    "the number of spaces in the operator's domain (%i)."
                    % (len(spaces), len(self_domain)))
            for space_index, operator_space in zip(spaces, self_domain):
                if x.domain[space_index] != operator_space:
                    raise ValueError(
                        "The operator's and field's domains don't "
                        "match.")

        return spaces

    def __repr__(self):
        return str(self.__class__)