fft_operator.py 10.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# NIFTy
# Copyright (C) 2017  Theo Steininger
#
# Author: Theo Steininger
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18

19
20
import numpy as np

Jait Dixit's avatar
Jait Dixit committed
21
import nifty.nifty_utilities as utilities
22
23
24
25
26
from nifty.spaces import RGSpace,\
                         GLSpace,\
                         HPSpace,\
                         LMSpace

Jait Dixit's avatar
Jait Dixit committed
27
from nifty.operators.linear_operator import LinearOperator
28
29
30
31
32
33
from transformations import RGRGTransformation,\
                            LMGLTransformation,\
                            LMHPTransformation,\
                            GLLMTransformation,\
                            HPLMTransformation,\
                            TransformationCache
Jait Dixit's avatar
Jait Dixit committed
34
35


Jait Dixit's avatar
Jait Dixit committed
36
class FFTOperator(LinearOperator):
    """ Transforms between a pair of position and harmonic domains.
    Possible domain pairs are
      - a harmonic and a non-harmonic RGSpace (with matching distances)
      - a HPSpace and a LMSpace
      - a GLSpace and a LMSpace
    Within a domain pair, both orderings are possible.

    The operator provides a "times" and an "adjoint_times" operation.
    For a pair of RGSpaces, the "adjoint_times" operation is equivalent to
    "inverse_times"; for the sphere-related domains this is not the case, since
    the operator matrix is not square.

    Parameters
    ----------

    domain: Space or single-element tuple of Spaces
        The domain of the data that is input by "times" and output by
        "adjoint_times".

    target: Space or single-element tuple of Spaces (optional)
        The domain of the data that is output by "times" and input by
        "adjoint_times".
        If omitted, a co-domain will be chosen automatically.
        Whenever "domain" is an RGSpace, the codomain (and its parameters) are
        uniquely determined. For GLSpace, HPSpace, and LMSpace, a sensible
        (but not unique) co-domain is chosen that should work satisfactorily in
        most situations, but for full control, the user should explicitly
        specify a codomain.
    module: String (optional)
        Software module employed for carrying out the transform operations.
        For RGSpace pairs this can be "numpy" or "fftw", where "numpy" is always
        available, but "fftw" offers higher performance and parallelization.
        For sphere-related domains, only "pyHealpix" is available.
        If omitted, "fftw" is selected for RGSpaces if available, else "numpy";
        on the sphere the default is (unsurprisingly) "pyHealpix".
    domain_dtype: data type (optional)
        Data type of the fields that go into "times" and come out of
        "adjoint_times". Default is numpy.float64. The value is always stored
        as a numpy.dtype instance.
    target_dtype: data type (optional)
        Data type of the fields that go into "adjoint_times" and come out of
        "times". Default is numpy.complex128. The value is always stored as a
        numpy.dtype instance.
        (MR: I feel this is not really a good idea, since it makes no sense for
        SHTs. Also, wouldn't it make sense to specify data types
        only to "times" and "adjoint_times"? Does the operator itself really
        need to know this, or only the individual call?)

    Attributes
    ----------

    domain: Tuple of Spaces (with one entry)
        The domain of the data that is input by "times" and output by
        "adjoint_times".
    target: Tuple of Spaces (with one entry)
        The domain of the data that is output by "times" and input by
        "adjoint_times".
    unitary: bool
        Returns False.
        Strictly speaking, this is a lie, because FFTOperators on RGSpaces are
        in fact unitary. However, if we returned True in that case,
        LinearOperator would call _inverse_times instead of _adjoint_times,
        and _inverse_times does not exist. This needs some more work.

    Raises
    ------

    ValueError:
        if "domain" or "target" are not of the proper type, or if no
        transformation is registered for the (domain, target) pair.
    """

    # ---Class attributes---

    # Space type -> space type used when no explicit target is given.
    default_codomain_dictionary = {RGSpace: RGSpace,
                                   HPSpace: LMSpace,
                                   GLSpace: LMSpace,
                                   LMSpace: GLSpace,
                                   }

    # (source space type, destination space type) -> Transformation class.
    # Lookups use the exact class, so subclasses of these spaces do not match.
    transformation_dictionary = {(RGSpace, RGSpace): RGRGTransformation,
                                 (HPSpace, LMSpace): HPLMTransformation,
                                 (GLSpace, LMSpace): GLLMTransformation,
                                 (LMSpace, HPSpace): LMHPTransformation,
                                 (LMSpace, GLSpace): LMGLTransformation
                                 }

    # ---Overwritten properties and methods---

    def __init__(self, domain, target=None, module=None,
                 domain_dtype=None, target_dtype=None):

        # Initialize domain and target
        self._domain = self._parse_domain(domain)
        if len(self.domain) != 1:
            raise ValueError("TransformationOperator accepts only exactly one "
                             "space as input domain.")

        if target is None:
            target = (self.get_default_codomain(self.domain[0]), )
        self._target = self._parse_domain(target)
        if len(self.target) != 1:
            raise ValueError("TransformationOperator accepts only exactly one "
                             "space as output target.")

        # Create transformation instances. An unsupported pair of space types
        # is reported as ValueError (as documented) rather than a bare
        # KeyError from the dictionary lookup.
        forward_class = self._get_transformation_class(
            self.domain[0].__class__, self.target[0].__class__)
        backward_class = self._get_transformation_class(
            self.target[0].__class__, self.domain[0].__class__)

        self._forward_transformation = TransformationCache.create(
            forward_class, self.domain[0], self.target[0], module=module)

        self._backward_transformation = TransformationCache.create(
            backward_class, self.target[0], self.domain[0], module=module)

        # MR FIXME: these defaults do not work for SHTs as they are currently
        #   implemented. Should have either float or complex on both sides.
        #   It would be great if this could be harmonized in some way.
        #   The simplest (and maybe safest, but expensive) solution would be
        #   to allow only complex-valued fields on both sides of an FFT between
        #   RGSpaces.
        # Store the dtype information. np.float / np.complex were only aliases
        # of the builtins and no longer exist in current NumPy, so the
        # explicit 64/128-bit types are used. Both the default and the
        # user-supplied value are normalised to np.dtype instances.
        if domain_dtype is None:
            self.logger.info("Setting domain_dtype to np.float.")
            domain_dtype = np.float64
        self.domain_dtype = np.dtype(domain_dtype)

        if target_dtype is None:
            self.logger.info("Setting target_dtype to np.complex.")
            target_dtype = np.complex128
        self.target_dtype = np.dtype(target_dtype)

    def _times(self, x, spaces):
        # Forward transform: domain[0] -> target[0].
        return self._transform_field(x, spaces,
                                     self._forward_transformation,
                                     self.target,
                                     self.target_dtype)

    def _adjoint_times(self, x, spaces):
        # Backward transform: target[0] -> domain[0].
        return self._transform_field(x, spaces,
                                     self._backward_transformation,
                                     self.domain,
                                     self.domain_dtype)

    def _transform_field(self, x, spaces, transformation, new_domain, dtype):
        """ Applies `transformation` to the space of `x` selected by `spaces`.

        Shared implementation of `_times` and `_adjoint_times`.

        Parameters
        ----------

        x: Field
            The input field as handed over by LinearOperator; its `domain`,
            `domain_axes`, `val` and `copy_empty` members are used.
        spaces: int, tuple of int or None
            Index of the space of `x` to transform. None means that `x` lives
            on exactly one space.
        transformation: Transformation
            The (cached) transformation whose `transform` is applied.
        new_domain: single-element tuple of Spaces
            Replaces the transformed space in the result's domain.
        dtype: numpy.dtype
            Data type of the returned field.

        Returns
        -------

        A new field with the transformed values.
        """
        spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
        if spaces is None:
            # this case means that x lives on only one space, which is
            # identical to the space in the domain of `self`. Otherwise the
            # input check of LinearOperator would have failed.
            axes = x.domain_axes[0]
            result_domain = new_domain
        else:
            axes = x.domain_axes[spaces[0]]
            result_domain = list(x.domain)
            result_domain[spaces[0]] = new_domain[0]

        new_val = transformation.transform(x.val, axes=axes)

        result_field = x.copy_empty(domain=result_domain, dtype=dtype)
        result_field.set_val(new_val=new_val, copy=False)
        return result_field

    # ---Mandatory properties and methods---

    @property
    def domain(self):
        return self._domain

    @property
    def target(self):
        return self._target

    @property
    def implemented(self):
        return True

    @property
    def unitary(self):
        # Deliberately False; see the "unitary" note in the class docstring.
        return False

    # ---Added properties and methods---

    @classmethod
    def _get_transformation_class(cls, domain_class, codomain_class):
        """ Looks up the Transformation class for a pair of space types.

        Raises
        ------

        ValueError:
            if no transformation is registered for this pair.
        """
        try:
            return cls.transformation_dictionary[(domain_class,
                                                  codomain_class)]
        except KeyError:
            raise ValueError(
                "No transformation for domain-codomain pair found.")

    @classmethod
    def get_default_codomain(cls, domain):
        """ Returns a codomain to the given domain.

        Parameters
        ----------

        domain: Space
            An instance of RGSpace, HPSpace, GLSpace or LMSpace.

        Returns
        -------

        target: Space
            A (more or less perfect) counterpart to "domain" with respect
            to a FFT operation.
            Whenever "domain" is an RGSpace, the codomain (and its parameters)
            are uniquely determined. For GLSpace, HPSpace, and LMSpace, a
            sensible (but not unique) co-domain is chosen that should work
            satisfactorily in most situations. For full control however, the
            user should not rely on this method.

        Raises
        ------

        ValueError:
            if no default codomain is defined for "domain".
        """
        domain_class = domain.__class__
        try:
            codomain_class = cls.default_codomain_dictionary[domain_class]
        except KeyError:
            raise ValueError("Unknown domain")

        transform_class = cls._get_transformation_class(domain_class,
                                                        codomain_class)
        return transform_class.get_codomain(domain)