utilities.py 11.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
Theo Steininger's avatar
Theo Steininger committed
13
#
Martin Reinecke's avatar
Martin Reinecke committed
14
# Copyright(C) 2013-2018 Max-Planck-Society
Theo Steininger's avatar
Theo Steininger committed
15
16
17
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
Ultima's avatar
Ultima committed
18

19
from __future__ import absolute_import, division, print_function
20

21
import collections
22
23
24
25
from itertools import product

import numpy as np
from future.utils import with_metaclass
Martin Reinecke's avatar
Martin Reinecke committed
26
27
import pyfftw
from pyfftw.interfaces.numpy_fft import rfftn, fftn
28
29

from .compat import *
30

Martin Reinecke's avatar
Martin Reinecke committed
31
# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    "get_slice_list", "safe_cast", "parse_spaces", "infer_space", "memo",
    "NiftyMetaBase", "fft_prep", "hartley", "my_fftn_r2c", "my_fftn",
    "my_sum", "my_lincomb_simple", "my_lincomb", "my_product",
    "frozendict", "special_add_at", "iscomplextype",
]
35
36


Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
37
38
def my_sum(iterable):
    """Return the left-to-right sum of the items of *iterable* using ``+``.

    Unlike the builtin ``sum`` there is no implicit start value of ``0``, so
    only the items' own ``__add__`` is involved. An empty iterable raises
    TypeError (from ``reduce``).
    """
    def _add(lhs, rhs):
        return lhs + rhs
    return reduce(_add, iterable)
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53


def my_lincomb_simple(terms, factors):
    """Return the linear combination sum(term*factor) of paired sequences.

    Terms and factors are paired with ``zip`` (extra elements of the longer
    sequence are ignored) and the products are summed with ``my_sum``.
    """
    products = (term*factor for term, factor in zip(terms, factors))
    return my_sum(products)


def my_lincomb(terms, factors):
    """Like ``my_lincomb_simple``, but skips multiplications by ``1.``.

    A term whose factor compares equal to ``1.`` is passed to ``my_sum``
    unchanged instead of being multiplied.
    """
    def _scaled(term, factor):
        # pass unit-weighted terms through without multiplying
        return term if factor == 1. else term*factor

    return my_sum(_scaled(term, factor)
                  for term, factor in zip(terms, factors))


def my_product(iterable):
    """Return the left-to-right product of the items of *iterable* using ``*``.

    There is no implicit start value of ``1``. An empty iterable raises
    TypeError (from ``reduce``).
    """
    def _mul(lhs, rhs):
        return lhs*rhs
    return reduce(_mul, iterable)
Martin Reinecke's avatar
Martin Reinecke committed
54

55

56
57
def get_slice_list(shape, axes):
    """
    Helper function which generates index tuple(s) to traverse over all
    combinations of axes, other than the selected axes.

    Parameters
    ----------
    shape: tuple
        Shape of the data array to traverse over.
    axes: tuple
        Axes which should not be iterated over. Negative values count from
        the end, as in numpy. If empty or None, a single full slice is
        yielded.

    Yields
    ------
    tuple
        The next tuple of indices and/or slice objects (one per dimension),
        usable directly as a numpy index.

    Raises
    ------
    ValueError
        If shape is None.
        If an axis is outside ``[-len(shape), len(shape))``.
    """
    if shape is None:
        raise ValueError("shape cannot be None.")

    if not axes:
        # Always yield a tuple: numpy rejects non-tuple sequences of slices
        # as multidimensional indices.
        yield (slice(None, None),)
        return

    ndim = len(shape)
    if not all(-ndim <= axis < ndim for axis in axes):
        raise ValueError("axes(axis) does not match shape.")
    # Normalize negative axes so that they are excluded from the iteration
    # instead of being silently ignored.
    fixed = set(axis % ndim for axis in axes)

    free_ranges = [range(n) for dim, n in enumerate(shape)
                   if dim not in fixed]
    for index in product(*free_ranges):
        free_iter = iter(index)
        yield tuple(slice(None, None) if dim in fixed else next(free_iter)
                    for dim in range(ndim))
Ultima's avatar
Ultima committed
97

Ultima's avatar
Ultima committed
98

99
100
101
102
103
104
105
def safe_cast(tfunc, val):
    """Convert *val* with *tfunc*, refusing lossy conversions.

    The converted value is compared with the input; if they differ (e.g.
    ``safe_cast(int, 2.5)``), ValueError is raised, otherwise the converted
    value is returned.
    """
    converted = tfunc(val)
    if val != converted:
        raise ValueError("value changed during cast")
    return converted


Martin Reinecke's avatar
Martin Reinecke committed
106
107
def parse_spaces(spaces, nspc):
    """Normalize a space specification into a tuple of space indices.

    Parameters
    ----------
    spaces : None, scalar, or iterable of integers
        None selects all spaces; a scalar selects a single space.
    nspc : int
        Total number of available spaces.

    Returns
    -------
    tuple of int
        The selected indices, in the order given by the caller.

    Raises
    ------
    ValueError
        If a value changes when cast to int, an index lies outside
        ``[0, nspc)``, or an index occurs more than once.
    """
    nspc = safe_cast(int, nspc)
    if spaces is None:
        return tuple(range(nspc))
    elif np.isscalar(spaces):
        spaces = (safe_cast(int, spaces),)
    else:
        spaces = tuple(safe_cast(int, item) for item in spaces)

    if len(spaces) == 0:
        return spaces
    # Use min/max: iteration order of a set is unspecified, so the first and
    # last elements of tuple(set(...)) are not the extremes.
    if min(spaces) < 0 or max(spaces) >= nspc:
        raise ValueError("space index out of range")
    if len(set(spaces)) != len(spaces):
        raise ValueError("multiply defined space indices")
    return spaces
Martin Reinecke's avatar
Martin Reinecke committed
122
123


124
125
126
127
128
129
130
131
132
133
134
def infer_space(domain, space):
    """Resolve *space* to a valid integer index into *domain*.

    If *space* is None, *domain* must contain exactly one entry and index 0
    is returned. Otherwise ``int(space)`` is returned after checking that it
    lies in ``[0, len(domain))``.

    Raises
    ------
    ValueError
        If *space* is None and *domain* does not have exactly one entry, or
        if the index is out of range.
    """
    if space is not None:
        index = int(space)
    elif len(domain) == 1:
        index = 0
    else:
        raise ValueError("need a Field with exactly one domain")
    if not 0 <= index < len(domain):
        raise ValueError("space index out of range")
    return index


Martin Reinecke's avatar
Martin Reinecke committed
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
def memo(f):
    """Decorator caching the result of a zero-argument method per instance.

    On the first call ``f(self)`` is evaluated and stored in the instance's
    ``_cache`` dict (created on demand) under ``f.__name__``; later calls
    return the stored value. Exceptions raised by ``f`` are not cached.

    NOTE: the cache key is only the function name, so an overriding method
    and the method it overrides (both decorated) share one cache entry.
    """
    from functools import wraps

    name = f.__name__

    @wraps(f)  # keep __name__/__doc__ of the decorated method
    def wrapped_f(self):
        if not hasattr(self, "_cache"):
            self._cache = {}
        try:
            return self._cache[name]
        except KeyError:
            result = f(self)
            self._cache[name] = result
            return result
    return wrapped_f


class _DocStringInheritor(type):
    """
    A variation on
    http://groups.google.com/group/comp.lang.python/msg/26f7b4fcb4d66c95
    by Paul McGuire
    """
    def __new__(meta, name, bases, clsdict):
        if not('__doc__' in clsdict and clsdict['__doc__']):
            for mro_cls in (mro_cls for base in bases
                            for mro_cls in base.mro()):
                doc = mro_cls.__doc__
                if doc:
                    clsdict['__doc__'] = doc
                    break
Martin Reinecke's avatar
Martin Reinecke committed
163
        for attr, attribute in clsdict.items():
Martin Reinecke's avatar
Martin Reinecke committed
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
            if not attribute.__doc__:
                for mro_cls in (mro_cls for base in bases
                                for mro_cls in base.mro()
                                if hasattr(mro_cls, attr)):
                    doc = getattr(getattr(mro_cls, attr), '__doc__')
                    if doc:
                        if isinstance(attribute, property):
                            clsdict[attr] = property(attribute.fget,
                                                     attribute.fset,
                                                     attribute.fdel,
                                                     doc)
                        else:
                            attribute.__doc__ = doc
                        break
        return super(_DocStringInheritor, meta).__new__(meta, name,
                                                        bases, clsdict)


Martin Reinecke's avatar
Martin Reinecke committed
182
class NiftyMeta(_DocStringInheritor):
    """Metaclass for NIFTy classes.

    Adds nothing to ``_DocStringInheritor``; it provides the distinct name
    that ``NiftyMetaBase`` uses as metaclass.
    """
Martin Reinecke's avatar
Martin Reinecke committed
184
185


Martin Reinecke's avatar
Martin Reinecke committed
186
187
188
189
def NiftyMetaBase():
    """Return a new base class whose metaclass is ``NiftyMeta``.

    Built with ``future.utils.with_metaclass`` on top of a fresh, empty
    ``NewBase`` class, so classes deriving from the result get NiftyMeta as
    their metaclass.
    """
    placeholder = type('NewBase', (object,), {})
    return with_metaclass(NiftyMeta, placeholder)


Martin Reinecke's avatar
Martin Reinecke committed
190
191
192
193
194
def nthreads():
    """Return the number of threads passed to the pyfftw FFT calls.

    The value is read once from the ``OMP_NUM_THREADS`` environment variable
    and cached in ``nthreads._val``; set that attribute to None to re-read
    the environment.

    OpenMP allows a comma-separated list (one entry per nesting level, e.g.
    ``"4,2"``); only the first entry is used here. An unset or empty
    variable yields 1, and the result is never smaller than 1.

    Raises
    ------
    ValueError
        If the first entry is not an integer.
    """
    if nthreads._val is None:
        import os
        first = os.getenv("OMP_NUM_THREADS", "").split(",")[0].strip()
        nthreads._val = max(1, int(first)) if first else 1
    return nthreads._val


nthreads._val = None
Martin Reinecke's avatar
Martin Reinecke committed
198

Martin Reinecke's avatar
Martin Reinecke committed
199
200
201
202
203
204
205
# Optional extra arguments for the FFT calls
# _fft_extra_args = {}
# if exact reproducibility is needed, use this:
_fft_extra_args = dict(planner_effort='FFTW_ESTIMATE')


def fft_prep():
    """Enable pyfftw's interface cache, once per process.

    The first call enables ``pyfftw.interfaces.cache`` and sets its
    keep-alive time to 1000 (seconds, per pyfftw's API); the flag
    ``fft_prep._initialized`` makes every later call a no-op.
    """
    if fft_prep._initialized:
        return
    cache = pyfftw.interfaces.cache
    cache.enable()
    cache.set_keepalive_time(1000.)
    fft_prep._initialized = True


fft_prep._initialized = False
Martin Reinecke's avatar
Martin Reinecke committed
211

Martin Reinecke's avatar
Martin Reinecke committed
212

Martin Reinecke's avatar
Martin Reinecke committed
213
214
215
216
217
def hartley(a, axes=None):
    """Return the discrete Hartley transform of a real array.

    The transform is assembled from a real-to-complex FFT (``rfftn``): the
    explicitly stored part of the spectrum yields ``Re + Im`` directly, and
    the remaining entries are reconstructed from the mirrored stored entries
    as ``Re - Im``.

    Parameters
    ----------
    a : numpy.ndarray
        Real-valued input array.
    axes : sequence of int, optional
        Axes to transform (all axes if None). Negative values count from
        the end.

    Returns
    -------
    numpy.ndarray
        Array of the same shape and dtype as ``a``.

    Raises
    ------
    ValueError
        If an axis is outside ``[-a.ndim, a.ndim)``.
    TypeError
        If ``a`` has a complex dtype.
    """
    ndim = len(a.shape)
    if axes is not None:
        if not all(-ndim <= axis < ndim for axis in axes):
            raise ValueError("Provided axes do not match array shape")
        # Normalize negative axes: the index tuples below are built with
        # ``(slice(None),)*axis`` and would address the wrong axis otherwise.
        axes = tuple(axis % ndim for axis in axes)
    if iscomplextype(a.dtype):
        raise TypeError("Hartley transform requires real-valued arrays.")

    tmp = rfftn(a, axes=axes, threads=nthreads(), **_fft_extra_args)

    def _fill_array(tmp, res, axes):
        if axes is None:
            axes = tuple(range(tmp.ndim))
        lastaxis = axes[-1]
        ntmplast = tmp.shape[lastaxis]
        # Part stored explicitly by rfftn along the last transformed axis.
        slice1 = (slice(None),)*lastaxis + (slice(0, ntmplast),)
        np.add(tmp.real, tmp.imag, out=res[slice1])

        def _fill_upper_half(tmp, res, axes):
            # For real input F(-k) = conj(F(k)), so Re+Im at a missing index
            # equals Re-Im at the mirrored stored index. slice2 selects the
            # mirrored entries (index j -> n-j on each transformed axis);
            # index 0 of the non-last axes mirrors onto itself and is handled
            # by the recursive calls on those zero-index subarrays.
            lastaxis = axes[-1]
            nlast = res.shape[lastaxis]
            ntmplast = tmp.shape[lastaxis]
            nrem = nlast - ntmplast
            slice1 = [slice(None)]*lastaxis + [slice(ntmplast, None)]
            slice2 = [slice(None)]*lastaxis + [slice(nrem, 0, -1)]
            for i in axes[:-1]:
                slice1[i] = slice(1, None)
                slice2[i] = slice(None, 0, -1)
            slice1 = tuple(slice1)
            slice2 = tuple(slice2)
            np.subtract(tmp[slice2].real, tmp[slice2].imag, out=res[slice1])
            for i, ax in enumerate(axes[:-1]):
                dim1 = (slice(None),)*ax + (slice(0, 1),)
                axes2 = axes[:i] + axes[i+1:]
                _fill_upper_half(tmp[dim1], res[dim1], axes2)

        _fill_upper_half(tmp, res, axes)
        return res

    return _fill_array(tmp, np.empty_like(a), axes)
Martin Reinecke's avatar
Martin Reinecke committed
253
254
255
256
257
258
259
260


# Do a real-to-complex forward FFT and return the _full_ output array
def my_fftn_r2c(a, axes=None):
    """Return the full complex FFT of a real array.

    Computes the half spectrum with ``rfftn`` and fills in the remaining
    entries as complex conjugates of the mirrored stored entries (Hermitian
    symmetry of real-input FFTs).

    Parameters
    ----------
    a : numpy.ndarray
        Real-valued input array.
    axes : sequence of int, optional
        Axes to transform (all axes if None). Negative values count from
        the end.

    Returns
    -------
    numpy.ndarray
        Complex array with the shape of ``a`` and the dtype of the ``rfftn``
        result.

    Raises
    ------
    ValueError
        If an axis is outside ``[-a.ndim, a.ndim)``.
    TypeError
        If ``a`` has a complex dtype.
    """
    ndim = len(a.shape)
    if axes is not None:
        if not all(-ndim <= axis < ndim for axis in axes):
            raise ValueError("Provided axes do not match array shape")
        # Normalize negative axes: the index tuples below are built with
        # ``(slice(None),)*axis`` and would address the wrong axis otherwise.
        axes = tuple(axis % ndim for axis in axes)
    if iscomplextype(a.dtype):
        raise TypeError("Transform requires real-valued input arrays.")

    tmp = rfftn(a, axes=axes, threads=nthreads(), **_fft_extra_args)

    def _fill_complex_array(tmp, res, axes):
        if axes is None:
            axes = tuple(range(tmp.ndim))
        lastaxis = axes[-1]
        ntmplast = tmp.shape[lastaxis]
        # Indices must be tuples: numpy rejects lists of slices as
        # multidimensional indices.
        slice1 = (slice(None),)*lastaxis + (slice(0, ntmplast),)
        res[slice1] = tmp

        def _fill_upper_half_complex(tmp, res, axes):
            # Missing entries are conjugates of the mirrored stored entries
            # (index j -> n-j on each transformed axis); index 0 of the
            # non-last axes is handled by the recursive calls below.
            lastaxis = axes[-1]
            nlast = res.shape[lastaxis]
            ntmplast = tmp.shape[lastaxis]
            nrem = nlast - ntmplast
            slice1 = [slice(None)]*lastaxis + [slice(ntmplast, None)]
            slice2 = [slice(None)]*lastaxis + [slice(nrem, 0, -1)]
            for i in axes[:-1]:
                slice1[i] = slice(1, None)
                slice2[i] = slice(None, 0, -1)
            res[tuple(slice1)] = np.conjugate(tmp[tuple(slice2)])
            for i, ax in enumerate(axes[:-1]):
                dim1 = (slice(None),)*ax + (slice(0, 1),)
                axes2 = axes[:i] + axes[i+1:]
                _fill_upper_half_complex(tmp[dim1], res[dim1], axes2)

        _fill_upper_half_complex(tmp, res, axes)
        return res

    return _fill_complex_array(tmp, np.empty_like(a, dtype=tmp.dtype), axes)
Martin Reinecke's avatar
Martin Reinecke committed
295
296
297
298


def my_fftn(a, axes=None):
    """Return the complex n-dimensional FFT of *a* over *axes* (all if None).

    Uses the same thread count (``nthreads()``) and planner arguments
    (``_fft_extra_args``) as the real-input FFT helpers in this module.
    """
    return fftn(a, axes=axes, threads=nthreads(), **_fft_extra_args)
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338


try:
    from collections.abc import Mapping as _Mapping
except ImportError:  # Python 2: the ABCs live directly in ``collections``
    from collections import Mapping as _Mapping


class frozendict(_Mapping):
    """
    An immutable wrapper around dictionaries that implements the complete
    :py:class:`collections.abc.Mapping` interface. It can be used as a
    drop-in replacement for dictionaries where immutability is desired.
    """

    dict_cls = dict

    def __init__(self, *args, **kwargs):
        self._dict = self.dict_cls(*args, **kwargs)
        self._hash = None

    def __getitem__(self, key):
        return self._dict[key]

    def __contains__(self, key):
        return key in self._dict

    def copy(self, **add_or_replace):
        """Return a new frozendict with *add_or_replace* entries merged in."""
        return self.__class__(self, **add_or_replace)

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        return len(self._dict)

    def __repr__(self):
        return '<%s %r>' % (self.__class__.__name__, self._dict)

    def __hash__(self):
        # XOR of the (key, value) hashes is independent of insertion order;
        # all values must be hashable. Computed lazily and cached.
        if self._hash is None:
            h = 0
            for key, value in self._dict.items():
                h ^= hash((key, value))
            self._hash = h
        return self._hash
Martin Reinecke's avatar
Martin Reinecke committed
339
340


Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
341
342
343
def special_add_at(a, axis, index, b):
    """Scatter-add *b* into *a* along *axis* at positions *index*.

    For each position ``j`` along *axis* of *b*, the slice ``b[..., j, ...]``
    is added to ``a[..., index[j], ...]``; repeated indices accumulate. The
    work is done with one ``np.bincount`` per combination of the other axes;
    complex arrays are processed as interleaved real/imaginary parts.

    Parameters
    ----------
    a : numpy.ndarray
        Target array. It may be modified in place (whenever ``a.reshape``
        returns a view), so rely on the return value.
    axis : int
        Axis along which to scatter; negative values count from the end.
    index : array of non-negative ints
        One target position per entry of *b* along *axis*. Presumably all
        entries must be < ``a.shape[axis]`` (otherwise ``bincount`` yields a
        longer array and the addition fails) -- confirm with callers.
    b : numpy.ndarray
        Values to add; must have the same dtype as *a* and (presumably) the
        same shape except along *axis*.

    Returns
    -------
    numpy.ndarray
        Array with the shape and dtype of *a* holding the accumulated sums.

    Raises
    ------
    TypeError
        If *a* and *b* have different dtypes.
    """
    if a.dtype != b.dtype:
        raise TypeError("data type mismatch")
    if axis < 0:
        # shape[:axis] / shape[axis+1:] below need a non-negative axis
        axis += a.ndim
    sz1 = int(np.prod(a.shape[:axis]))
    sz3 = int(np.prod(a.shape[axis+1:]))
    a2 = a.reshape([sz1, -1, sz3])
    b2 = b.reshape([sz1, -1, sz3])
    if iscomplextype(a.dtype):
        # bincount needs real weights: view each complex number as two reals
        dt2 = a.real.dtype
        a2 = a2.view(dt2)
        b2 = b2.view(dt2)
        sz3 *= 2
    for i1 in range(sz1):
        for i3 in range(sz3):
            a2[i1, :, i3] += np.bincount(index, b2[i1, :, i3],
                                         minlength=a2.shape[1])

    if iscomplextype(a.dtype):
        a2 = a2.view(a.dtype)
    return a2.reshape(a.shape)
Martin Reinecke's avatar
Martin Reinecke committed
361
362
363
364
365


# Scalar types that iscomplextype treats as complex; includes the
# extended-precision np.clongdouble besides complex64/complex128.
_iscomplex_tpl = (np.complex64, np.complex128, np.clongdouble)


def iscomplextype(dtype):
    """Return True if *dtype* (a ``numpy.dtype``) is a complex float type."""
    return dtype.type in _iscomplex_tpl