# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
Ultima's avatar
Ultima committed
18

19
from builtins import *
Ultima's avatar
Ultima committed
20
import numpy as np
21
from itertools import product
Martin Reinecke's avatar
Martin Reinecke committed
22
import abc
Martin Reinecke's avatar
Martin Reinecke committed
23
from future.utils import with_metaclass
24
from functools import reduce
25

Martin Reinecke's avatar
Martin Reinecke committed
26
# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    "get_slice_list", "safe_cast", "parse_spaces", "infer_space",
    "memo", "NiftyMetaBase", "fft_prep", "hartley", "my_fftn_r2c",
    "my_fftn", "my_sum", "my_lincomb_simple", "my_lincomb",
    "my_product",
]


def my_sum(terms):
    """Return ``t0 + t1 + ...`` for the items of `terms`, left to right.

    No start value is involved (unlike the builtin ``sum``), so only the
    items' own ``+`` is used. An empty iterable raises ``TypeError``
    (from ``reduce``).
    """
    return reduce(lambda acc, term: acc + term, terms)


def my_lincomb_simple(terms, factors):
    """Return the linear combination ``sum(t*f for t, f in zip(...))``.

    Every term is multiplied by its factor (pairs beyond the shorter
    input are ignored, as with ``zip``); the products are then added
    left to right without a start value. Empty input raises ``TypeError``.
    """
    products = (term*factor for term, factor in zip(terms, factors))
    return reduce(lambda acc, prod: acc + prod, products)


def my_lincomb(terms, factors):
    """Return the linear combination of `terms` weighted by `factors`.

    Same as a plain term-by-factor sum, except that a term whose factor
    compares equal to ``1.`` is used as-is instead of being multiplied,
    which saves a multiplication (and a copy for array-like terms).
    Empty input raises ``TypeError``.
    """
    weighted = (term if factor == 1. else term*factor
                for term, factor in zip(terms, factors))
    return reduce(lambda acc, item: acc + item, weighted)


def my_product(iterable):
    """Return ``x0 * x1 * ...`` for the items of `iterable`, left to right.

    No start value is used; an empty iterable raises ``TypeError``.
    """
    return reduce(lambda acc, factor: acc*factor, iterable)
Martin Reinecke's avatar
Martin Reinecke committed
49

50

51
52
def get_slice_list(shape, axes):
    """
    Helper function which generates slice list(s) to traverse over all
    combinations of axes, other than the selected axes.

    Parameters
    ----------
    shape: tuple
        Shape of the data array to traverse over.
    axes: tuple
        Axes which should not be iterated over. Every entry must satisfy
        ``0 <= axis < len(shape)``.

    Yields
    ------
    list
        The next list of indices and/or slice objects for each dimension:
        an integer index for every axis that is iterated over and
        ``slice(None, None)`` for every axis in `axes`. Indices vary
        fastest along the last iterated axis. If `axes` is empty or None,
        a single ``[slice(None, None)]`` is yielded.

    Raises
    ------
    ValueError
        If shape is None.
        If axes(axis) does not match shape (negative or >= len(shape)).

    Notes
    -----
    This is a generator: the errors are raised on the first iteration,
    not when the function is called.
    """
    if shape is None:
        raise ValueError("shape cannot be None.")

    if not axes:
        yield [slice(None, None)]
        return

    ndim = len(shape)
    # Negative axes would never match a dimension number below and would
    # silently be iterated over, so they are rejected like too-large ones.
    if not all(0 <= axis < ndim for axis in axes):
        raise ValueError("axes(axis) does not match shape.")
    kept = set(axes)

    # Dimensions to iterate over, in their natural order.
    loop_dims = [dim for dim in range(ndim) if dim not in kept]
    for index in product(*(range(shape[dim]) for dim in loop_dims)):
        slice_list = [slice(None, None)]*ndim
        for dim, i in zip(loop_dims, index):
            slice_list[dim] = i
        yield slice_list
Ultima's avatar
Ultima committed
92

Ultima's avatar
Ultima committed
93

94
95
96
97
98
99
100
def safe_cast(tfunc, val):
    """Convert `val` with `tfunc`, refusing conversions that change it.

    Raises
    ------
    ValueError
        If the converted value does not compare equal to `val`
        (e.g. ``safe_cast(int, 2.5)``).
    """
    converted = tfunc(val)
    if val != converted:
        raise ValueError("value changed during cast")
    return converted


def parse_spaces(spaces, nspc):
    """Normalize a space specification to a tuple of space indices.

    Parameters
    ----------
    spaces : None, int or iterable of int
        The selected space indices; None selects all spaces.
    nspc : int
        Total number of spaces available.

    Returns
    -------
    tuple of int
        ``tuple(range(nspc))`` if `spaces` is None, otherwise the given
        indices in their original order.

    Raises
    ------
    ValueError
        If a value cannot be converted to int without change, if an index
        lies outside ``[0, nspc)``, or if an index occurs more than once.
    """
    nspc = safe_cast(int, nspc)
    if spaces is None:
        return tuple(range(nspc))
    elif np.isscalar(spaces):
        spaces = (safe_cast(int, spaces),)
    else:
        spaces = tuple(safe_cast(int, item) for item in spaces)
    if len(spaces) == 0:
        return spaces
    # Use min/max explicitly: iteration order of a set is not sorted, so
    # the first/last element of tuple(set(...)) are not the extremes.
    if min(spaces) < 0 or max(spaces) >= nspc:
        raise ValueError("space index out of range")
    if len(set(spaces)) != len(spaces):
        raise ValueError("multiply defined space indices")
    return spaces
Martin Reinecke's avatar
Martin Reinecke committed
117
118


119
120
121
122
123
124
125
126
127
128
129
def infer_space(domain, space):
    """Return a validated space index into `domain`.

    If `space` is None, `domain` must contain exactly one entry and index
    0 is returned; otherwise `space` is converted with ``int`` and checked
    against ``len(domain)``.

    Raises
    ------
    ValueError
        If `space` is None and `domain` does not have exactly one entry,
        or if the index is outside ``[0, len(domain))``.
    """
    if space is not None:
        space = int(space)
    elif len(domain) == 1:
        space = 0
    else:
        raise ValueError("need a Field with exactly one domain")
    if not 0 <= space < len(domain):
        raise ValueError("space index out of range")
    return space


Martin Reinecke's avatar
Martin Reinecke committed
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
def memo(f):
    """Decorator caching the result of a method that takes only `self`.

    The first call stores ``f(self)`` in the per-instance dict
    ``self._cache`` (created on demand) under ``f.__name__``; later calls
    return the stored value without calling `f` again. The wrapper keeps
    `f`'s name and docstring via ``functools.wraps``.

    Notes
    -----
    Two memoized methods with the same ``__name__`` on one instance would
    share a cache slot.
    """
    import functools
    name = f.__name__

    @functools.wraps(f)
    def wrapped_f(self):
        if not hasattr(self, "_cache"):
            self._cache = {}
        try:
            return self._cache[name]
        except KeyError:
            result = self._cache[name] = f(self)
            return result
    return wrapped_f


class _DocStringInheritor(type):
    """
    Metaclass that lets a class and its members inherit docstrings.

    A variation on
    http://groups.google.com/group/comp.lang.python/msg/26f7b4fcb4d66c95
    by Paul McGuire

    If the class body has no (or an empty) docstring, the class takes the
    first non-empty docstring found while walking the MRO of each base in
    turn. Likewise, every class-body attribute without a docstring takes
    the docstring of the first same-named attribute with one in those
    MROs: properties are rebuilt with the inherited doc, other objects get
    ``__doc__`` assigned. Objects whose ``__doc__`` cannot be assigned
    (e.g. ``None`` in ``__hash__ = None``) are left undocumented.
    """
    def __new__(meta, name, bases, clsdict):
        # Every base's MRO, base by base (classes may appear repeatedly).
        ancestors = [anc for base in bases for anc in base.mro()]

        if not clsdict.get('__doc__'):
            for anc in ancestors:
                if anc.__doc__:
                    clsdict['__doc__'] = anc.__doc__
                    break

        for attr, attribute in list(clsdict.items()):
            if attribute.__doc__:
                continue
            for anc in ancestors:
                if not hasattr(anc, attr):
                    continue
                doc = getattr(anc, attr).__doc__
                if not doc:
                    continue
                if isinstance(attribute, property):
                    # A property's doc is fixed at construction; rebuild it.
                    clsdict[attr] = property(attribute.fget, attribute.fset,
                                             attribute.fdel, doc)
                else:
                    try:
                        attribute.__doc__ = doc
                    except (AttributeError, TypeError):
                        # Read-only __doc__ (e.g. None, builtins): docstring
                        # inheritance is cosmetic, so do not fail class
                        # creation over it.
                        pass
                break
        return super(_DocStringInheritor, meta).__new__(meta, name,
                                                        bases, clsdict)


class NiftyMeta(_DocStringInheritor, abc.ABCMeta):
    """Metaclass for NIFTy classes.

    Combines ``_DocStringInheritor`` (undocumented classes and members
    inherit docstrings from their bases) with ``abc.ABCMeta`` (support
    for abstract methods).
    """
Martin Reinecke's avatar
Martin Reinecke committed
179
180


Martin Reinecke's avatar
Martin Reinecke committed
181
182
183
184
def NiftyMetaBase():
    """Return a new base class to derive from to get ``NiftyMeta``.

    ``future.utils.with_metaclass`` is used so that the metaclass is
    attached without Python-2- or Python-3-specific class syntax. Each call
    builds a fresh plain base named ``NewBase``.
    """
    plain_base = type('NewBase', (object,), {})
    return with_metaclass(NiftyMeta, plain_base)


Martin Reinecke's avatar
Martin Reinecke committed
185
186
187
188
189
def nthreads():
    """Return the number of threads to request for FFT calls.

    The value is read once from the ``OMP_NUM_THREADS`` environment
    variable (default 1) and cached in ``nthreads._val``; reset that
    attribute to None to force a re-read.

    OpenMP allows a comma-separated list (one count per nesting level);
    only the first entry is used. Empty, non-integer or non-positive
    values fall back to 1 instead of raising.
    """
    if nthreads._val is None:
        import os
        raw = os.getenv("OMP_NUM_THREADS", "1")
        first = raw.split(",")[0].strip()
        try:
            nthreads._val = max(1, int(first))
        except ValueError:
            nthreads._val = 1
    return nthreads._val


nthreads._val = None
Martin Reinecke's avatar
Martin Reinecke committed
193

Martin Reinecke's avatar
Martin Reinecke committed
194
195
196
197
198
199
200
201
202
203
204
# Extra keyword arguments passed to the pyfftw FFT calls in this module.
# 'FFTW_ESTIMATE' makes FFTW pick a plan by heuristics instead of timing
# candidate algorithms, which keeps the chosen plan (and thus the exact
# floating-point results) reproducible between runs. Replace with an empty
# dict to use pyfftw's configured default planner effort instead.
_fft_extra_args = {'planner_effort': 'FFTW_ESTIMATE'}


def fft_prep():
    """Enable pyfftw's interface cache.

    Cached FFTW objects are kept alive for at least 1000 seconds, so
    repeated transforms of the same shape/dtype can reuse their plans.
    """
    from pyfftw.interfaces import cache as fftw_cache
    fftw_cache.enable()
    fftw_cache.set_keepalive_time(1000.)
    pyfftw.interfaces.cache.set_keepalive_time(1000.)

Martin Reinecke's avatar
Martin Reinecke committed
205

Martin Reinecke's avatar
Martin Reinecke committed
206
207
208
209
210
def hartley(a, axes=None):
    """Compute the Hartley transform of a real-valued array.

    The result equals ``F.real + F.imag``, where ``F`` is the full complex
    forward FFT of `a` over `axes`. Only the non-redundant half of ``F`` is
    computed (real-to-complex FFT, halved along ``axes[-1]``); the other
    entries follow from the Hermitian symmetry ``F[-k] == conj(F[k])``.

    Parameters
    ----------
    a : numpy.ndarray
        Real-valued input array.
    axes : sequence of int, optional
        Axes over which to transform; negative values count from the end.
        Defaults to all axes.

    Returns
    -------
    numpy.ndarray
        Real array of the same shape as `a`. Its dtype is that of `a` for
        floating-point input, otherwise the real dtype of the FFT output.

    Raises
    ------
    ValueError
        If an entry of `axes` is out of range for `a`.
    TypeError
        If `a` is complex-valued.
    """
    ndim = len(a.shape)
    if axes is None:
        axes = tuple(range(ndim))
    else:
        axes = tuple(axes)
        # Check if the axes provided are valid given the shape
        if not all(-ndim <= axis < ndim for axis in axes):
            raise ValueError("Provided axes do not match array shape")
        # The slice bookkeeping below indexes lists by axis number, so
        # negative axes are converted to their non-negative equivalents.
        axes = tuple(axis % ndim for axis in axes)
    if np.issubdtype(a.dtype, np.complexfloating):
        raise TypeError("Hartley transform requires real-valued arrays.")

    from pyfftw.interfaces.numpy_fft import rfftn
    tmp = rfftn(a, axes=axes, threads=nthreads(), **_fft_extra_args)

    def _fill_upper_half(tmp, res, axes):
        # Fill the entries of `res` not covered by the half spectrum `tmp`
        # along axes[-1]. Index tuples span all dimensions, so the order
        # of `axes` does not matter.
        lastaxis = axes[-1]
        ntmplast = tmp.shape[lastaxis]
        nrem = res.shape[lastaxis] - ntmplast
        dst = [slice(None)]*res.ndim
        src = [slice(None)]*res.ndim
        dst[lastaxis] = slice(ntmplast, None)
        src[lastaxis] = slice(nrem, 0, -1)
        # Index k >= 1 along the other transformed axes mirrors to N-k;
        # index 0 mirrors to itself and is handled by the recursion below.
        for ax in axes[:-1]:
            dst[ax] = slice(1, None)
            src[ax] = slice(None, 0, -1)
        mirrored = tmp[tuple(src)]
        # Re(conj(z)) + Im(conj(z)) == z.real - z.imag
        np.subtract(mirrored.real, mirrored.imag, out=res[tuple(dst)])
        for i, ax in enumerate(axes[:-1]):
            plane = [slice(None)]*res.ndim
            plane[ax] = slice(0, 1)
            plane = tuple(plane)
            _fill_upper_half(tmp[plane], res[plane], axes[:i] + axes[i+1:])

    # empty_like(a) for integer input could not hold the float results
    # (np.add refuses the unsafe cast), so use the FFT's real dtype there.
    if np.issubdtype(a.dtype, np.floating):
        res = np.empty_like(a)
    else:
        res = np.empty_like(a, dtype=tmp.real.dtype)

    lastaxis = axes[-1]
    lower = [slice(None)]*ndim
    lower[lastaxis] = slice(0, tmp.shape[lastaxis])
    np.add(tmp.real, tmp.imag, out=res[tuple(lower)])
    _fill_upper_half(tmp, res, axes)
    return res
Martin Reinecke's avatar
Martin Reinecke committed
245
246
247
248
249
250
251
252


def my_fftn_r2c(a, axes=None):
    """Real-to-complex forward FFT that returns the *full* output array.

    Only the non-redundant half spectrum is computed (halved along
    ``axes[-1]``); the remaining entries are filled in from the Hermitian
    symmetry ``F[-k] == conj(F[k])`` of the transform of real input.

    Parameters
    ----------
    a : numpy.ndarray
        Real-valued input array.
    axes : sequence of int, optional
        Axes over which to transform; negative values count from the end.
        Defaults to all axes.

    Returns
    -------
    numpy.ndarray
        Complex array of the same shape as `a`, with the dtype of the
        real-to-complex FFT output.

    Raises
    ------
    ValueError
        If an entry of `axes` is out of range for `a`.
    TypeError
        If `a` is complex-valued.
    """
    ndim = len(a.shape)
    if axes is None:
        axes = tuple(range(ndim))
    else:
        axes = tuple(axes)
        # Check if the axes provided are valid given the shape
        if not all(-ndim <= axis < ndim for axis in axes):
            raise ValueError("Provided axes do not match array shape")
        # The slice bookkeeping below indexes lists by axis number, so
        # negative axes are converted to their non-negative equivalents.
        axes = tuple(axis % ndim for axis in axes)
    if np.issubdtype(a.dtype, np.complexfloating):
        raise TypeError("Transform requires real-valued input arrays.")

    from pyfftw.interfaces.numpy_fft import rfftn
    tmp = rfftn(a, axes=axes, threads=nthreads(), **_fft_extra_args)

    def _fill_upper_half_complex(tmp, res, axes):
        # Fill the entries of `res` not covered by the half spectrum `tmp`
        # along axes[-1]. Index tuples span all dimensions, so the order
        # of `axes` does not matter.
        lastaxis = axes[-1]
        ntmplast = tmp.shape[lastaxis]
        nrem = res.shape[lastaxis] - ntmplast
        dst = [slice(None)]*res.ndim
        src = [slice(None)]*res.ndim
        dst[lastaxis] = slice(ntmplast, None)
        src[lastaxis] = slice(nrem, 0, -1)
        # Index k >= 1 along the other transformed axes mirrors to N-k;
        # index 0 mirrors to itself and is handled by the recursion below.
        for ax in axes[:-1]:
            dst[ax] = slice(1, None)
            src[ax] = slice(None, 0, -1)
        res[tuple(dst)] = np.conjugate(tmp[tuple(src)])
        for i, ax in enumerate(axes[:-1]):
            plane = [slice(None)]*res.ndim
            plane[ax] = slice(0, 1)
            plane = tuple(plane)
            _fill_upper_half_complex(tmp[plane], res[plane],
                                     axes[:i] + axes[i+1:])

    res = np.empty_like(a, dtype=tmp.dtype)
    lastaxis = axes[-1]
    lower = [slice(None)]*ndim
    lower[lastaxis] = slice(0, tmp.shape[lastaxis])
    res[tuple(lower)] = tmp
    _fill_upper_half_complex(tmp, res, axes)
    return res
Martin Reinecke's avatar
Martin Reinecke committed
288
289
290
291
292


def my_fftn(a, axes=None):
    """Complex forward FFT of `a` over `axes` (default: all axes).

    Uses pyfftw's numpy-compatible interface with the same thread count
    (``nthreads()``) and extra planner arguments as the other FFT helpers
    in this module.
    """
    from pyfftw.interfaces.numpy_fft import fftn
    return fftn(a, axes=axes, threads=nthreads(), **_fft_extra_args)