linear_operator.py 4.45 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# NIFTy
# Copyright (C) 2017  Theo Steininger
#
# Author: Theo Steininger
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18

19
20
import abc

21
from keepers import Loggable
22
23
24
25
from nifty.field import Field
import nifty.nifty_utilities as utilities


theos's avatar
theos committed
26
class LinearOperator(Loggable, object):
    """Abstract base class for linear operators acting on `nifty.Field`s.

    Subclasses provide the abstract `domain`, `target` and `unitary`
    properties and override whichever private hooks they support:
    `_times`, `_adjoint_times`, `_inverse_times`, `_adjoint_inverse_times`
    and `_inverse_adjoint_times`.

    The public methods of the same names do two things:
    - validate the input field against `domain` or `target`, depending on
      the direction;
    - dispatch to the matching hook, forwarding any extra keyword
      arguments.
    """

    # Python 2 metaclass declaration; NOTE(review): this attribute is
    # ignored under Python 3, where the abstract members are not enforced.
    __metaclass__ = abc.ABCMeta

    def __init__(self):
        pass

    def _parse_domain(self, domain):
        """Normalize `domain` via `nifty_utilities.parse_domain`."""
        return utilities.parse_domain(domain)

    @abc.abstractproperty
    def domain(self):
        """Domain that input fields of `times` are checked against."""
        raise NotImplementedError

    @abc.abstractproperty
    def target(self):
        """Domain that input fields of `inverse_times` and
        `adjoint_times` are checked against."""
        raise NotImplementedError

    @abc.abstractproperty
    def unitary(self):
        """Truthy if the operator is unitary.

        When truthy, the adjoint methods are routed through `times` /
        `inverse_times` instead of their dedicated hooks.
        """
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        """Alias for `times`."""
        return self.times(*args, **kwargs)

    def times(self, x, spaces=None, **kwargs):
        """Apply the operator to `x`.

        Parameters
        ----------
        x : Field
            Input field, checked against `self.domain`.
        spaces : optional
            Indices into `x.domain` that the operator acts on. The value is
            normalized by `utilities.cast_axis_to_tuple`. `None` requires
            `x.domain` to equal the operator's domain exactly.
        **kwargs
            Forwarded to `_times`.

        Raises
        ------
        ValueError
            If `x` is not a Field or its domain does not match.
        """
        spaces = self._check_input_compatibility(x, spaces)
        return self._times(x, spaces, **kwargs)

    def inverse_times(self, x, spaces=None, **kwargs):
        """Apply the inverse operator to `x` (checked against `target`).

        `spaces` and `**kwargs` behave as in `times`; the keyword
        arguments are forwarded to `_inverse_times`.
        """
        spaces = self._check_input_compatibility(x, spaces, inverse=True)
        return self._inverse_times(x, spaces, **kwargs)

    def adjoint_times(self, x, spaces=None, **kwargs):
        """Apply the adjoint operator to `x` (checked against `target`).

        For a unitary operator the adjoint equals the inverse, so the call
        is delegated to `inverse_times`. The keyword arguments are
        forwarded on both paths.
        """
        if self.unitary:
            return self.inverse_times(x, spaces, **kwargs)

        spaces = self._check_input_compatibility(x, spaces, inverse=True)
        return self._adjoint_times(x, spaces, **kwargs)

    def adjoint_inverse_times(self, x, spaces=None, **kwargs):
        """Apply the inverse of the adjoint to `x` (checked against
        `domain`).

        For a unitary operator this equals the operator itself, so the
        call is delegated to `times`. The keyword arguments are forwarded
        on both paths.
        """
        if self.unitary:
            return self.times(x, spaces, **kwargs)

        spaces = self._check_input_compatibility(x, spaces)
        return self._adjoint_inverse_times(x, spaces, **kwargs)

    def inverse_adjoint_times(self, x, spaces=None, **kwargs):
        """Apply the adjoint of the inverse to `x` (checked against
        `domain`).

        For a unitary operator this equals the operator itself, so the
        call is delegated to `times`. The keyword arguments are forwarded
        on both paths.
        """
        if self.unitary:
            return self.times(x, spaces, **kwargs)

        spaces = self._check_input_compatibility(x, spaces)
        return self._inverse_adjoint_times(x, spaces, **kwargs)

    # Default hooks. They accept **kwargs so that an unimplemented hook
    # always reports NotImplementedError, even when the public method
    # forwards keyword arguments.

    def _times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'times'.")

    def _adjoint_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'adjoint_times'.")

    def _inverse_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'inverse_times'.")

    def _adjoint_inverse_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'adjoint_inverse_times'.")

    def _inverse_adjoint_times(self, x, spaces, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'inverse_adjoint_times'.")

    def _check_input_compatibility(self, x, spaces, inverse=False):
        """Validate `x` against the operator and normalize `spaces`.

        Parameters
        ----------
        x : Field
            The field the operator is about to be applied to.
        spaces : optional
            Raw `spaces` argument from the public method. It is normalized
            via `utilities.cast_axis_to_tuple(spaces, len(x.domain))`.
        inverse : bool
            If true, compare against `self.target` instead of
            `self.domain`.

        Returns
        -------
        The normalized `spaces`.

        Raises
        ------
        ValueError
            Raised in any of these cases:
            - `x` is not a Field;
            - the number of selected spaces differs from the number of
              entries in the operator's domain;
            - the selected entries of `x.domain` do not equal the
              operator's domain.
        """
        if not isinstance(x, Field):
            raise ValueError(
                "supplied object is not a `nifty.Field`.")

        # sanitize the `spaces` input
        spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))

        # The forward direction is checked against the domain; the
        # inverse/adjoint directions are checked against the target.
        self_domain = self.target if inverse else self.domain

        # There are two valid cases:
        # 1. `spaces` is None: the domains of self and x match completely.
        # 2. `spaces` selects, in order, the entries of x.domain on which
        #    the operator's domain is applied.
        if spaces is None:
            if self_domain != x.domain:
                raise ValueError(
                    "The operator's and field's domains don't match.")
        else:
            # Without this check, surplus spaces raised a bare IndexError
            # and missing spaces were silently accepted.
            if len(spaces) != len(self_domain):
                raise ValueError(
                    "The number of spaces (%d) does not match the number "
                    "of entries in the operator's domain (%d)."
                    % (len(spaces), len(self_domain)))
            for space_index, own_space in zip(spaces, self_domain):
                if x.domain[space_index] != own_space:
                    raise ValueError(
                        "The operator's and field's domains don't match.")

        return spaces

    def __repr__(self):
        return str(self.__class__)