linear_operator.py 4.49 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# NIFTy
# Copyright (C) 2017  Theo Steininger
#
# Author: Theo Steininger
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18

19
import abc
20
from nifty.nifty_meta import NiftyMeta
21

22
from keepers import Loggable
23
24
25
26
from nifty.field import Field
import nifty.nifty_utilities as utilities


theos's avatar
theos committed
27
class LinearOperator(Loggable, object):
    """Abstract base class for linear operators acting on `nifty.Field`s.

    Subclasses must provide the abstract properties `domain`, `target` and
    `unitary`, and override whichever of the private hooks `_times`,
    `_inverse_times`, `_adjoint_times`, `_adjoint_inverse_times` and
    `_inverse_adjoint_times` they support. The public wrappers defined here
    validate the input field against the operator's domain (or target)
    before dispatching to those hooks; all keyword arguments given to a
    public wrapper are forwarded to the corresponding hook.
    """

    __metaclass__ = NiftyMeta

    def __init__(self):
        pass

    def _parse_domain(self, domain):
        """Normalize `domain` via `nifty.nifty_utilities.parse_domain`."""
        return utilities.parse_domain(domain)

    @abc.abstractproperty
    def domain(self):
        """Input domain of the operator; `times` checks its input against it.

        NOTE(review): presumably a tuple of spaces -- it is compared with and
        indexed like `Field.domain` below; confirm in subclasses.
        """
        raise NotImplementedError

    @abc.abstractproperty
    def target(self):
        """Output domain of the operator.

        Inputs to `inverse_times` and `adjoint_times` are checked against it.
        """
        raise NotImplementedError

    @abc.abstractproperty
    def unitary(self):
        """Whether the operator is unitary.

        If true, `adjoint_times` delegates to `inverse_times`, and
        `adjoint_inverse_times` / `inverse_adjoint_times` delegate to
        `times`, instead of calling the dedicated adjoint hooks.
        """
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        """Alias for `times`."""
        return self.times(*args, **kwargs)

    def times(self, x, spaces=None, **kwargs):
        """Apply the operator to the field `x`.

        Parameters
        ----------
        x : Field
            Input field, checked against `self.domain`.
        spaces : optional
            Index/indices into `x.domain` the operator acts on (sanitized by
            `cast_axis_to_tuple`). If None, `x.domain` must equal
            `self.domain` as a whole.
        **kwargs
            Forwarded to `_times`.

        Raises
        ------
        ValueError
            If `x` is not a Field or its domain does not match.
        """
        spaces = self._check_input_compatibility(x, spaces)
        return self._times(x, spaces, **kwargs)

    def inverse_times(self, x, spaces=None, **kwargs):
        """Apply the inverse operator to `x`, which is checked against
        `self.target`. Keyword arguments are forwarded to `_inverse_times`.
        """
        spaces = self._check_input_compatibility(x, spaces, inverse=True)
        return self._inverse_times(x, spaces, **kwargs)

    def adjoint_times(self, x, spaces=None, **kwargs):
        """Apply the adjoint operator to `x`, which is checked against
        `self.target`. Keyword arguments are forwarded to `_adjoint_times`.
        """
        if self.unitary:
            # For a unitary operator the adjoint equals the inverse.
            return self.inverse_times(x, spaces, **kwargs)
        spaces = self._check_input_compatibility(x, spaces, inverse=True)
        return self._adjoint_times(x, spaces, **kwargs)

    def adjoint_inverse_times(self, x, spaces=None, **kwargs):
        """Apply the inverse of the adjoint to `x`, which is checked against
        `self.domain`. Keyword arguments are forwarded to
        `_adjoint_inverse_times`.
        """
        if self.unitary:
            # (A^dagger)^-1 == (A^-1)^-1 == A for a unitary operator.
            return self.times(x, spaces, **kwargs)
        spaces = self._check_input_compatibility(x, spaces)
        return self._adjoint_inverse_times(x, spaces, **kwargs)

    def inverse_adjoint_times(self, x, spaces=None, **kwargs):
        """Apply the adjoint of the inverse to `x`, which is checked against
        `self.domain`. Keyword arguments are forwarded to
        `_inverse_adjoint_times`.
        """
        if self.unitary:
            # (A^-1)^dagger == A for a unitary operator.
            return self.times(x, spaces, **kwargs)
        spaces = self._check_input_compatibility(x, spaces)
        # Forward kwargs like every other wrapper does (they were dropped).
        return self._inverse_adjoint_times(x, spaces, **kwargs)

    # The default hooks accept **kwargs so that calling an unimplemented
    # operation with keyword arguments raises NotImplementedError rather
    # than a TypeError about an unexpected keyword argument.

    def _times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'times'.")

    def _adjoint_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'adjoint_times'.")

    def _inverse_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'inverse_times'.")

    def _adjoint_inverse_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'adjoint_inverse_times'.")

    def _inverse_adjoint_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'inverse_adjoint_times'.")

    def _check_input_compatibility(self, x, spaces, inverse=False):
        """Validate `x` against the operator and return sanitized `spaces`.

        Parameters
        ----------
        x : Field
            The input field.
        spaces : optional
            Index/indices into `x.domain`, or None.
        inverse : bool
            If True, check against `self.target` (used by `inverse_times`
            and `adjoint_times`); otherwise check against `self.domain`.

        Returns
        -------
        The result of `utilities.cast_axis_to_tuple(spaces, len(x.domain))`.

        Raises
        ------
        ValueError
            If `x` is not a Field, if the number of given spaces differs
            from the number of spaces of the operator's domain/target, or if
            the domains do not match.
        """
        if not isinstance(x, Field):
            raise ValueError(
                "supplied object is not a `nifty.Field`.")

        # sanitize the `spaces` input
        spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))

        self_domain = self.target if inverse else self.domain

        # There are two valid cases:
        # 1. `spaces` is None: the domains of self and x must match
        #    completely.
        # 2. `spaces` is given: the operator's domain is applied to the
        #    listed spaces of x's domain tuple, one operator space per index.
        if spaces is None:
            if self_domain != x.domain:
                raise ValueError(
                    "The operator's and field's domains don't match.")
        else:
            # Without this check, too many spaces raised an opaque
            # IndexError and too few silently passed a partial comparison.
            if len(spaces) != len(self_domain):
                raise ValueError(
                    "The number of given spaces (%d) does not match the "
                    "number of spaces in the operator's domain (%d)."
                    % (len(spaces), len(self_domain)))
            for space_index, own_space in zip(spaces, self_domain):
                if x.domain[space_index] != own_space:
                    raise ValueError(
                        "The operator's and field's domains don't match.")

        return spaces

    def __repr__(self):
        return str(self.__class__)