fft_operator.py 5.59 KB
Newer Older
1

Jait Dixit's avatar
Jait Dixit committed
2
import nifty.nifty_utilities as utilities
3
4
5
6
7
from nifty.spaces import RGSpace,\
                         GLSpace,\
                         HPSpace,\
                         LMSpace

Jait Dixit's avatar
Jait Dixit committed
8
from nifty.operators.linear_operator import LinearOperator
9
10
11
12
13
14
from transformations import RGRGTransformation,\
                            LMGLTransformation,\
                            LMHPTransformation,\
                            GLLMTransformation,\
                            HPLMTransformation,\
                            TransformationCache
Jait Dixit's avatar
Jait Dixit committed
15
16


Jait Dixit's avatar
Jait Dixit committed
17
18
class FFTOperator(LinearOperator):
    """Linear operator transforming fields between a space and its codomain.

    The operator's domain consists of exactly one space. Its target is either
    passed explicitly or derived from the domain via `get_default_codomain`.
    Forward application (`_times`) uses the transformation class registered in
    `transformation_dictionary` for (domain class, target class). Inverse
    application (`_inverse_times`) uses the one registered for
    (target class, domain class). Both directions must be registered,
    otherwise construction fails.
    """

    # ---Class attributes---

    # Maps a domain space class to the space class used as its target when
    # no explicit target is given (consulted by `get_default_codomain`).
    default_codomain_dictionary = {RGSpace: RGSpace,
                                   HPSpace: LMSpace,
                                   GLSpace: LMSpace,
                                   LMSpace: HPSpace,
                                   }

    # Maps a (source space class, destination space class) pair to the
    # transformation class that maps values from the first onto the second.
    transformation_dictionary = {(RGSpace, RGSpace): RGRGTransformation,
                                 (HPSpace, LMSpace): HPLMTransformation,
                                 (GLSpace, LMSpace): GLLMTransformation,
                                 (LMSpace, HPSpace): LMHPTransformation,
                                 (LMSpace, GLSpace): LMGLTransformation
                                 }

    # ---Overwritten properties and methods---

    def __init__(self, domain=(), field_type=(), target=None, module=None):
        """Set up domain/target and build forward and backward transformations.

        Parameters
        ----------
        domain : space or tuple of spaces
            Parsed with `_parse_domain`; must contain exactly one space.
        field_type : tuple
            Parsed with `_parse_field_type`; must be an empty tuple.
        target : space or tuple of spaces, optional
            Destination of the forward transformation. If None, the default
            codomain of the single domain space is used.
        module : optional
            Passed unchanged to `TransformationCache.create` for both
            directions (presumably selects the backend implementation --
            confirm against `transformations`).

        Raises
        ------
        ValueError
            If the domain does not hold exactly one space, if the field type
            is not empty, if no default codomain is known (when `target` is
            None), or if either transformation direction is not registered.
        """
        self._domain = self._parse_domain(domain)
        self._field_type = self._parse_field_type(field_type)

        # Initialize domain and target
        if len(self.domain) != 1:
            raise ValueError(
                'ERROR: FFTOperator accepts only exactly one '
                'space as input domain.')

        if self.field_type != ():
            raise ValueError(
                'ERROR: FFTOperator field-type must be an '
                'empty tuple.'
            )

        if target is None:
            target = (self.get_default_codomain(self.domain[0]), )
        self._target = self._parse_domain(target)

        # Create transformation instances. Both directions are resolved up
        # front so that a one-way pair fails at construction time instead of
        # on the first inverse application.
        domain_class = self.domain[0].__class__
        target_class = self.target[0].__class__

        try:
            forward_class = self.transformation_dictionary[
                (domain_class, target_class)]
        except KeyError:
            raise ValueError(
                "No forward transformation for domain-target pair "
                "found.")

        try:
            backward_class = self.transformation_dictionary[
                (target_class, domain_class)]
        except KeyError:
            raise ValueError(
                "No backward transformation for domain-target pair "
                "found.")

        self._forward_transformation = TransformationCache.create(
            forward_class, self.domain[0], self.target[0], module=module)

        self._backward_transformation = TransformationCache.create(
            backward_class, self.target[0], self.domain[0], module=module)

    def _transform_field(self, x, spaces, transformation, new_domain):
        """Apply `transformation` to `x` along one space and wrap the result.

        Shared implementation of `_times` and `_inverse_times`. Those two
        differ only in the transformation used and in the domain that replaces
        the transformed space.

        Parameters
        ----------
        x : Field
            Input field.
        spaces : as accepted by `utilities.cast_axis_to_tuple`, or None
            Index of the space of `x.domain` to transform.
        transformation :
            Object whose `transform(val, axes=...)` produces the new values.
        new_domain : tuple of spaces
            Used as the full result domain when `spaces` is None. Otherwise
            its first entry replaces the transformed space.

        Returns
        -------
        Field
            A `copy_empty` of `x` on the new domain holding the transformed
            values.
        """
        spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
        if spaces is None:
            # this case means that x lives on only one space, which is
            # identical to the space in the domain of `self`. Otherwise the
            # input check of LinearOperator would have failed.
            axes = x.domain_axes[0]
        else:
            axes = x.domain_axes[spaces[0]]

        new_val = transformation.transform(x.val, axes=axes)

        if spaces is None:
            result_domain = new_domain
        else:
            # Only the transformed space is swapped; all others are kept.
            result_domain = list(x.domain)
            result_domain[spaces[0]] = new_domain[0]

        result_field = x.copy_empty(domain=result_domain)
        result_field.set_val(new_val=new_val)

        return result_field

    def _times(self, x, spaces, types):
        """Transform `x` from the operator's domain into its target.

        `types` is part of the LinearOperator hook signature and is not used.
        """
        return self._transform_field(x, spaces,
                                     self._forward_transformation,
                                     self.target)

    def _inverse_times(self, x, spaces, types):
        """Transform `x` from the operator's target back into its domain.

        `types` is part of the LinearOperator hook signature and is not used.
        """
        return self._transform_field(x, spaces,
                                     self._backward_transformation,
                                     self.domain)

    # ---Mandatory properties and methods---

    @property
    def domain(self):
        """Parsed domain; holds exactly one space (enforced in __init__)."""
        return self._domain

    @property
    def target(self):
        """Parsed target spaces of the forward transformation."""
        return self._target

    @property
    def field_type(self):
        """Field type of the domain; always the empty tuple."""
        return self._field_type

    @property
    def field_type_target(self):
        """Field type of the target; identical to `field_type`."""
        return self.field_type

    @property
    def implemented(self):
        """Always True for this operator."""
        return True

    @property
    def unitary(self):
        """Reports this operator as unitary (always True)."""
        return True

    # ---Added properties and methods---

    @classmethod
    def get_default_codomain(cls, domain):
        """Return the default codomain space instance for `domain`.

        The codomain class is looked up in `default_codomain_dictionary`. The
        instance is then built by the registered transformation class's
        `get_codomain`.

        Raises
        ------
        ValueError
            If the domain's class has no default codomain, or if no
            transformation is registered for the (domain, codomain) pair.
        """
        domain_class = domain.__class__
        try:
            codomain_class = cls.default_codomain_dictionary[domain_class]
        except KeyError:
            raise ValueError("Unknown domain")

        try:
            transform_class = cls.transformation_dictionary[(domain_class,
                                                             codomain_class)]
        except KeyError:
            raise ValueError(
                "No transformation for domain-codomain pair found.")

        return transform_class.get_codomain(domain)