# -*- coding: utf-8 -*-

import abc

from keepers import Loggable
from nifty.field import Field
from nifty.spaces import Space
from nifty.field_types import FieldType
import nifty.nifty_utilities as utilities


class LinearOperator(Loggable, object):
    """Abstract base class for linear operators acting on nifty Fields.

    Subclasses provide the abstract properties (`domain`, `target`,
    `field_type`, `field_type_target`, `implemented`, `unitary`) and override
    whichever of the private hooks `_times`, `_adjoint_times`,
    `_inverse_times`, `_adjoint_inverse_times` and `_inverse_adjoint_times`
    they support.

    The public methods (`times`, `inverse_times`, `adjoint_times`,
    `adjoint_inverse_times`, `inverse_adjoint_times`) perform three steps:
    - validate the input Field against the operator's domain or target;
    - apply the `Field.weight` calls controlled by `implemented`;
    - dispatch to the corresponding hook.
    """

    __metaclass__ = abc.ABCMeta

    def __init__(self):
        # Performs no initialisation at this level.
        # NOTE(review): Loggable.__init__ is not called here; confirm that
        # this is intentional.
        pass

    def _parse_domain(self, domain):
        """Normalise `domain` to a tuple of `Space` instances.

        None becomes (), a single Space becomes a 1-tuple, and any other
        non-tuple iterable is converted with `tuple(...)`.

        Raises
        ------
        TypeError
            If any element of the resulting tuple is not a `Space`.
        """
        if domain is None:
            domain = ()
        elif isinstance(domain, Space):
            domain = (domain,)
        elif not isinstance(domain, tuple):
            domain = tuple(domain)

        if not all(isinstance(d, Space) for d in domain):
            raise TypeError(
                "Given object contains something that is not a "
                "nifty.space.")
        return domain

    def _parse_field_type(self, field_type):
        """Normalise `field_type` to a tuple of `FieldType` instances.

        None becomes (), a single FieldType becomes a 1-tuple, and any other
        non-tuple iterable is converted with `tuple(...)`.

        Raises
        ------
        TypeError
            If any element of the resulting tuple is not a `FieldType`.
        """
        if field_type is None:
            field_type = ()
        elif isinstance(field_type, FieldType):
            field_type = (field_type,)
        elif not isinstance(field_type, tuple):
            field_type = tuple(field_type)

        if not all(isinstance(ft, FieldType) for ft in field_type):
            raise TypeError(
                "Given object is not a nifty.FieldType.")
        return field_type

    @abc.abstractproperty
    def domain(self):
        """Spaces against which inputs of `times`/`adjoint_inverse_times`
        (and `inverse_adjoint_times`) are checked."""
        raise NotImplementedError

    @abc.abstractproperty
    def target(self):
        """Spaces against which inputs of `inverse_times`/`adjoint_times`
        are checked."""
        raise NotImplementedError

    @abc.abstractproperty
    def field_type(self):
        """Field types paired with `domain` in the input check."""
        raise NotImplementedError

    @abc.abstractproperty
    def field_type_target(self):
        """Field types paired with `target` in the input check."""
        raise NotImplementedError

    @abc.abstractproperty
    def implemented(self):
        """Whether the hooks already handle weighting themselves.

        If False, `times`/`adjoint_times` call `x.weight(...)` on their input,
        and the inverse variants call `y.weight(power=-1, ...)` on their
        output.
        """
        raise NotImplementedError

    @abc.abstractproperty
    def unitary(self):
        """Whether the operator is unitary.

        If True, `adjoint_times` delegates to `inverse_times`, and
        `adjoint_inverse_times`/`inverse_adjoint_times` delegate to `times`.
        """
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        """Shorthand for `times(*args, **kwargs)`."""
        return self.times(*args, **kwargs)

    def times(self, x, spaces=None, types=None, **kwargs):
        """Apply the operator to the Field `x`.

        Parameters
        ----------
        x : Field
            Input field, checked against `domain`/`field_type`.
        spaces : indices into `x.domain` the operator acts on, or None
            Normalised by `utilities.cast_axis_to_tuple`. None means the
            whole domain of `x` must match.
        types : indices into `x.field_type` the operator acts on, or None
            Normalised like `spaces`.
        **kwargs
            Forwarded to `_times`.

        Returns
        -------
        Whatever `_times` returns.

        Raises
        ------
        ValueError
            If `x` is not a Field or does not match the operator.
        """
        spaces, types = self._check_input_compatibility(x, spaces, types)

        # Operators whose hooks do not weight by themselves get a weighted
        # input (semantics defined by Field.weight).
        if not self.implemented:
            x = x.weight(spaces=spaces)

        return self._times(x, spaces, types, **kwargs)

    def inverse_times(self, x, spaces=None, types=None, **kwargs):
        """Apply the inverse operator to `x` (checked against `target`).

        Parameters are as for `times`; `**kwargs` is forwarded to
        `_inverse_times`.
        """
        spaces, types = self._check_input_compatibility(x, spaces, types,
                                                        inverse=True)

        y = self._inverse_times(x, spaces, types, **kwargs)
        # Undo the weighting convention on the output (power -1).
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def adjoint_times(self, x, spaces=None, types=None, **kwargs):
        """Apply the adjoint operator to `x` (checked against `target`).

        For unitary operators this delegates to `inverse_times`. Parameters
        are as for `times`; `**kwargs` is forwarded to the hook.
        """
        if self.unitary:
            return self.inverse_times(x, spaces, types, **kwargs)

        spaces, types = self._check_input_compatibility(x, spaces, types,
                                                        inverse=True)

        if not self.implemented:
            x = x.weight(spaces=spaces)

        return self._adjoint_times(x, spaces, types, **kwargs)

    def adjoint_inverse_times(self, x, spaces=None, types=None, **kwargs):
        """Apply the inverse of the adjoint to `x` (checked against `domain`).

        For unitary operators this delegates to `times`. Parameters are as
        for `times`; `**kwargs` is forwarded to the hook.
        """
        if self.unitary:
            return self.times(x, spaces, types, **kwargs)

        spaces, types = self._check_input_compatibility(x, spaces, types)

        y = self._adjoint_inverse_times(x, spaces, types, **kwargs)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    def inverse_adjoint_times(self, x, spaces=None, types=None, **kwargs):
        """Apply the adjoint of the inverse to `x` (checked against `domain`).

        For unitary operators this delegates to `times`. Parameters are as
        for `times`; `**kwargs` is forwarded to the hook.
        """
        if self.unitary:
            return self.times(x, spaces, types, **kwargs)

        spaces, types = self._check_input_compatibility(x, spaces, types)

        y = self._inverse_adjoint_times(x, spaces, types, **kwargs)
        if not self.implemented:
            y = y.weight(power=-1, spaces=spaces)
        return y

    # Default hooks. They accept **kwargs so that the public methods, which
    # always forward keyword arguments, reach the NotImplementedError below
    # instead of failing with a signature TypeError.

    def _times(self, x, spaces, types, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'times'.")

    def _adjoint_times(self, x, spaces, types, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'adjoint_times'.")

    def _inverse_times(self, x, spaces, types, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'inverse_times'.")

    def _adjoint_inverse_times(self, x, spaces, types, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'adjoint_inverse_times'.")

    def _inverse_adjoint_times(self, x, spaces, types, **kwargs):
        raise NotImplementedError(
            "no generic instance method 'inverse_adjoint_times'.")

    @staticmethod
    def _check_match(field_items, operator_items, indices, label):
        """Raise ValueError unless the field's items match the operator's.

        If `indices` is None, the whole `field_items` tuple must equal
        `operator_items`; an empty `operator_items` accepts anything.
        Otherwise `indices` must have one entry per element of
        `operator_items`, and `field_items[indices[i]]` must equal
        `operator_items[i]` for every i.
        """
        message = "The operator's and field's %s don't match." % label

        if indices is None:
            if operator_items != () and operator_items != field_items:
                raise ValueError(message)
            return

        # Without this, a too-short index tuple would pass on a partial
        # match and a too-long one would raise a bare IndexError.
        if len(indices) != len(operator_items):
            raise ValueError(message)

        for index, expected in zip(indices, operator_items):
            if field_items[index] != expected:
                raise ValueError(message)

    def _check_input_compatibility(self, x, spaces, types, inverse=False):
        """Validate `x` against the operator and normalise `spaces`/`types`.

        Parameters
        ----------
        x : Field
        spaces, types : index specifications, or None
            Normalised with `utilities.cast_axis_to_tuple`.
        inverse : bool
            If True, check against `target`/`field_type_target` instead of
            `domain`/`field_type`.

        Returns
        -------
        tuple
            The normalised `(spaces, types)`.

        Raises
        ------
        ValueError
            If `x` is not a Field or its domain/field types do not match.
        """
        if not isinstance(x, Field):
            raise ValueError(
                "supplied object is not a `nifty.Field`.")

        # sanitize the `spaces` and `types` input
        spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
        types = utilities.cast_axis_to_tuple(types, len(x.field_type))

        # If the operator's domain is set to something, there are two valid
        # cases:
        # 1. The user specifies with `spaces` that the operator's domain
        #    should be applied to certain spaces in the domain-tuple of x.
        # 2. The domains of self and x match completely.
        if inverse:
            self_domain = self.target
            self_field_type = self.field_type_target
        else:
            self_domain = self.domain
            self_field_type = self.field_type

        self._check_match(x.domain, self_domain, spaces, "domains")
        self._check_match(x.field_type, self_field_type, types, "field_types")

        return (spaces, types)

    def __repr__(self):
        return str(self.__class__)