utilities.py 7.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
Theo Steininger's avatar
Theo Steininger committed
13
#
Martin Reinecke's avatar
Martin Reinecke committed
14
# Copyright(C) 2013-2018 Max-Planck-Society
Theo Steininger's avatar
Theo Steininger committed
15
16
17
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
Ultima's avatar
Ultima committed
18

19
import collections
20
21
22
from itertools import product

import numpy as np
Martin Reinecke's avatar
fix    
Martin Reinecke committed
23
from future.utils import with_metaclass
24
25

from .compat import *
26

Martin Reinecke's avatar
Martin Reinecke committed
27
# Public names exported by ``from ... import *``.
__all__ = [
    "get_slice_list",
    "safe_cast",
    "parse_spaces",
    "infer_space",
    "memo",
    "NiftyMetaBase",
    "my_sum",
    "my_lincomb_simple",
    "my_lincomb",
    "indent",
    "my_product",
    "frozendict",
    "special_add_at",
    "iscomplextype",
]
31
32


Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
33
34
def my_sum(iterable):
    """
    Return the left-to-right sum of all items of `iterable`.

    Items are combined with ``+`` (never ``+=``), so no item is modified in
    place. An empty iterable raises TypeError (no start value is used).
    """
    def _add(left, right):
        return left + right

    return reduce(_add, iterable)
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49


def my_lincomb_simple(terms, factors):
    """
    Return ``terms[0]*factors[0] + terms[1]*factors[1] + ...``.

    Every term is multiplied by its factor unconditionally. Pairs are formed
    with ``zip``, so surplus entries of the longer argument are ignored.
    """
    products = (term*factor for term, factor in zip(terms, factors))
    return my_sum(products)


def my_lincomb(terms, factors):
    """
    Return ``terms[0]*factors[0] + terms[1]*factors[1] + ...``.

    A term whose factor compares equal to 1 is used as-is instead of being
    multiplied. Pairs are formed with ``zip``, so surplus entries of the
    longer argument are ignored.
    """
    def _scaled(term, factor):
        # Skip the multiplication for unit factors.
        if factor == 1.:
            return term
        return term*factor

    return my_sum(_scaled(term, factor)
                  for term, factor in zip(terms, factors))


def my_product(iterable):
    """
    Return the left-to-right product of all items of `iterable`.

    Items are combined with ``*`` (never ``*=``), so no item is modified in
    place. An empty iterable raises TypeError (no start value is used).
    """
    def _mul(left, right):
        return left*right

    return reduce(_mul, iterable)
Martin Reinecke's avatar
Martin Reinecke committed
50

51

52
53
def get_slice_list(shape, axes):
    """
    Helper function which generates slice list(s) to traverse over all
    combinations of axes, other than the selected axes.

    Parameters
    ----------
    shape: tuple
        Shape of the data array to traverse over.
    axes: tuple
        Axes which should not be iterated over.

    Yields
    ------
    tuple
        The next tuple of indices and/or slice objects, one per dimension.
        Entries for the axes listed in `axes` are ``slice(None, None)``; all
        other entries are integer indices. If `axes` is empty or None, a
        single ``(slice(None, None),)`` is yielded.

    Raises
    ------
    ValueError
        If shape is None.
        If any entry of axes is not a valid non-negative axis of shape.
    """
    if shape is None:
        raise ValueError("shape cannot be None.")

    if not axes:
        # Nothing to iterate over. Yield a tuple (like the branch below) so
        # the result is always usable directly as a NumPy index; lists of
        # slices are not accepted as multidimensional indices by NumPy.
        yield (slice(None, None),)
        return

    ndim = len(shape)
    # Negative axes would never match ``x in axes`` below and would be
    # silently iterated over, so reject them together with too-large ones.
    if not all(0 <= axis < ndim for axis in axes):
        raise ValueError("axes(axis) does not match shape.")

    # sliced[x] is True for axes that keep a full slice.
    sliced = [x in axes for x in range(ndim)]
    ranges = [range(n) for x, n in enumerate(shape) if not sliced[x]]
    for index in product(*ranges):
        index_iter = iter(index)
        yield tuple(slice(None, None) if keep else next(index_iter)
                    for keep in sliced)
Ultima's avatar
Ultima committed
93

Ultima's avatar
Ultima committed
94

95
96
97
98
99
100
101
def safe_cast(tfunc, val):
    """
    Convert `val` with `tfunc`, insisting that the conversion is lossless.

    Parameters
    ----------
    tfunc : callable
        Conversion function, e.g. ``int``.
    val : object
        Value to convert.

    Returns
    -------
    object
        ``tfunc(val)``.

    Raises
    ------
    ValueError
        If ``val != tfunc(val)``, i.e. the conversion changed the value.
        NOTE(review): since NaN != NaN, a NaN input is always rejected.
    """
    converted = tfunc(val)
    lossy = val != converted
    if lossy:
        raise ValueError("value changed during cast")
    return converted


Martin Reinecke's avatar
Martin Reinecke committed
102
103
def parse_spaces(spaces, nspc):
    """
    Normalize a space specification into a tuple of validated indices.

    Parameters
    ----------
    spaces : None, scalar, or iterable
        Requested space indices. None selects all `nspc` spaces.
    nspc : int
        Total number of available spaces.

    Returns
    -------
    tuple of int
        ``tuple(range(nspc))`` if `spaces` is None; otherwise the requested
        indices, converted to int, in the order they were given.

    Raises
    ------
    ValueError
        If a value changes when converted to int (see ``safe_cast``), if an
        index lies outside ``[0, nspc)``, or if an index appears twice.
    """
    nspc = safe_cast(int, nspc)
    if spaces is None:
        return tuple(range(nspc))
    elif np.isscalar(spaces):
        spaces = (safe_cast(int, spaces),)
    else:
        spaces = tuple(safe_cast(int, item) for item in spaces)
    if len(spaces) == 0:
        return spaces
    # Sort the unique indices so that the first and last entries really are
    # the minimum and maximum. The iteration order of a set is arbitrary
    # (e.g. negative or large ints need not come out ascending), so checking
    # the ends of an unsorted set let out-of-range indices slip through.
    unique = sorted(set(spaces))
    if unique[0] < 0 or unique[-1] >= nspc:
        raise ValueError("space index out of range")
    if len(unique) != len(spaces):
        raise ValueError("multiply defined space indices")
    return spaces
Martin Reinecke's avatar
Martin Reinecke committed
118
119


120
121
122
def infer_space(domain, space):
    """
    Return a validated index into `domain`.

    Parameters
    ----------
    domain : sized container
        Only ``len(domain)`` is used. Presumably a DomainTuple, judging from
        the error message below.
    space : int or None
        Requested index. If None, `domain` must have exactly one entry and
        0 is returned.

    Returns
    -------
    int
        The validated index.

    Raises
    ------
    ValueError
        If `space` is None and ``len(domain) != 1``, if `space` changes when
        converted to int (e.g. 1.5), or if it lies outside
        ``[0, len(domain))``.
    """
    if space is None:
        if len(domain) != 1:
            raise ValueError("'space' index must be given for objects based on"
                             " DomainTuples containing more than one domain")
        space = 0
    # Lossless conversion, as in parse_spaces: a non-integral index is an
    # error rather than something to truncate silently.
    space = safe_cast(int, space)
    if space < 0 or space >= len(domain):
        raise ValueError("space index out of range")
    return space


Martin Reinecke's avatar
Martin Reinecke committed
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
def memo(f):
    """
    Decorator caching the result of a method that takes only ``self``.

    The first call computes ``f(self)`` and stores the result in the
    instance's ``_cache`` dict (created on demand) under ``f.__name__``;
    later calls on the same instance return the stored value.

    The wrapper copies ``f``'s metadata (``__name__``, ``__doc__``, ...) via
    ``functools.wraps``. Decorated methods therefore keep their docstrings,
    which the docstring-inheriting metaclass in this module reads.

    NOTE(review): the cache key is only the method name, so an overriding
    memoized method and the base method it calls via ``super()`` share one
    cache slot.
    """
    from functools import wraps

    name = f.__name__

    @wraps(f)
    def wrapped_f(self):
        if not hasattr(self, "_cache"):
            self._cache = {}
        try:
            return self._cache[name]
        except KeyError:
            self._cache[name] = f(self)
            return self._cache[name]
    return wrapped_f


class _DocStringInheritor(type):
    """
    A variation on
    http://groups.google.com/group/comp.lang.python/msg/26f7b4fcb4d66c95
    by Paul McGuire
    """
    def __new__(meta, name, bases, clsdict):
        if not('__doc__' in clsdict and clsdict['__doc__']):
            for mro_cls in (mro_cls for base in bases
                            for mro_cls in base.mro()):
                doc = mro_cls.__doc__
                if doc:
                    clsdict['__doc__'] = doc
                    break
Martin Reinecke's avatar
Martin Reinecke committed
160
        for attr, attribute in clsdict.items():
Martin Reinecke's avatar
Martin Reinecke committed
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
            if not attribute.__doc__:
                for mro_cls in (mro_cls for base in bases
                                for mro_cls in base.mro()
                                if hasattr(mro_cls, attr)):
                    doc = getattr(getattr(mro_cls, attr), '__doc__')
                    if doc:
                        if isinstance(attribute, property):
                            clsdict[attr] = property(attribute.fget,
                                                     attribute.fset,
                                                     attribute.fdel,
                                                     doc)
                        else:
                            attribute.__doc__ = doc
                        break
        return super(_DocStringInheritor, meta).__new__(meta, name,
                                                        bases, clsdict)


Martin Reinecke's avatar
Martin Reinecke committed
179
class NiftyMeta(_DocStringInheritor):
    """Metaclass of NIFTy classes; currently adds nothing beyond the
    docstring inheritance of `_DocStringInheritor`."""
Martin Reinecke's avatar
Martin Reinecke committed
181
182


Martin Reinecke's avatar
Martin Reinecke committed
183
184
185
186
def NiftyMetaBase():
    """
    Return a base class whose metaclass is `NiftyMeta`.

    The metaclass is attached via ``future.utils.with_metaclass`` to a fresh,
    empty helper class named ``NewBase``. A new class is created on every
    call.
    """
    plain_base = type('NewBase', (object,), {})
    return with_metaclass(NiftyMeta, plain_base)


187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
try:
    from collections.abc import Mapping as _Mapping
except ImportError:  # Python 2: the ABCs live directly in `collections`
    from collections import Mapping as _Mapping


class frozendict(_Mapping):
    """
    An immutable wrapper around dictionaries that implements the complete
    :py:class:`collections.abc.Mapping` interface. It can be used as a
    drop-in replacement for dictionaries where immutability is desired.

    Instances are hashable as long as all keys and values are hashable; the
    hash is computed on first use and then cached.
    """

    # Class used for the internal storage; subclasses may override it.
    dict_cls = dict

    def __init__(self, *args, **kwargs):
        # Accepts the same arguments as the `dict_cls` constructor.
        self._dict = self.dict_cls(*args, **kwargs)
        self._hash = None

    def __getitem__(self, key):
        return self._dict[key]

    def __contains__(self, key):
        return key in self._dict

    def copy(self, **add_or_replace):
        """Return a new instance with these entries, updated by the given
        keyword arguments."""
        return self.__class__(self, **add_or_replace)

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        return len(self._dict)

    def __repr__(self):
        return '<{} {}>'.format(self.__class__.__name__, self._dict)

    def __hash__(self):
        if self._hash is None:
            # XOR of the item hashes: independent of insertion order.
            h = 0
            for key, value in self._dict.items():
                h ^= hash((key, value))
            self._hash = h
        return self._hash
Martin Reinecke's avatar
Martin Reinecke committed
225
226


Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
227
228
229
def special_add_at(a, axis, index, b):
    """
    Add the entries of `b` into `a` at positions `index` along `axis`.

    For each position ``j`` along `axis`, the values ``b[..., j, ...]`` are
    added to ``a[..., index[j], ...]``; repeated entries in `index`
    accumulate. The summation is done with ``np.bincount``, using slices of
    `b` as weights.

    Parameters
    ----------
    a : numpy.ndarray
        Array to accumulate into.
    axis : int
        Axis along which `index` addresses `a` and `b`.
    index : array-like of int
        Target position in `a` (along `axis`) for every entry of `b` along
        `axis`; passed to ``np.bincount``, so it must be non-negative.
    b : numpy.ndarray
        Values to add; must have the same dtype as `a`. Presumably its shape
        equals ``a.shape`` except along `axis` — TODO confirm with callers.

    Returns
    -------
    numpy.ndarray
        The accumulated array, with shape ``a.shape``.

    Raises
    ------
    TypeError
        If `a` and `b` have different dtypes.

    Notes
    -----
    Complex data (as classified by `iscomplextype`) is viewed as interleaved
    real and imaginary parts, so each part is accumulated as a separate real
    entry.
    NOTE(review): ``a.reshape`` returns a view when it can, in which case `a`
    itself is updated in place too; otherwise only the returned array holds
    the result. Callers should rely on the return value.
    """
    if a.dtype != b.dtype:
        raise TypeError("data type mismatch")
    is_cplx = iscomplextype(a.dtype)
    n_before = int(np.prod(a.shape[:axis]))
    n_after = int(np.prod(a.shape[axis+1:]))
    acc = a.reshape([n_before, -1, n_after])
    src = b.reshape([n_before, -1, n_after])
    if is_cplx:
        # Reinterpret every complex value as two consecutive reals; this
        # doubles the length of the trailing axis.
        real_dt = a.real.dtype
        acc, src = acc.view(real_dt), src.view(real_dt)
        n_after *= 2
    nbins = acc.shape[1]
    for lo in range(n_before):
        for hi in range(n_after):
            acc[lo, :, hi] += np.bincount(index, src[lo, :, hi],
                                          minlength=nbins)
    if is_cplx:
        acc = acc.view(a.dtype)
    return acc.reshape(a.shape)
Martin Reinecke's avatar
Martin Reinecke committed
247
248
249
250
251


# Scalar types treated as complex by `iscomplextype`.
_iscomplex_tpl = (np.complex64, np.complex128)


def iscomplextype(dtype):
    """
    Return True if `dtype` is ``np.complex64`` or ``np.complex128``.

    Parameters
    ----------
    dtype : numpy.dtype or anything accepted by ``np.dtype``
        E.g. ``np.dtype(np.complex128)``, ``np.complex128`` or ``complex``.
        The value is normalized with ``np.dtype`` first; before, passing a
        scalar type such as ``np.complex128`` failed because it has no
        ``.type`` attribute.

    Returns
    -------
    bool
        Other complex types (e.g. ``np.clongdouble``) are not included.
    """
    return np.dtype(dtype).type in _iscomplex_tpl
252
253
254
255


def indent(inp):
    """
    Indent every line of `inp` by two spaces.

    Trailing whitespace is stripped from each resulting line, so blank lines
    stay empty. Lines are split with ``str.splitlines`` and re-joined with
    ``"\\n"``; a trailing newline in `inp` is not reproduced.
    """
    indented = []
    for line in inp.splitlines():
        indented.append(("  " + line).rstrip())
    return "\n".join(indented)