# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import collections
import collections.abc
from functools import reduce, wraps
from itertools import product

import numpy as np

# Names exported by ``from <this module> import *``.
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
           "memo", "NiftyMeta", "my_sum", "my_lincomb_simple",
           "my_lincomb", "indent",
           "my_product", "frozendict", "special_add_at", "iscomplextype"]


def my_sum(iterable):
    """Fold the items of `iterable` together with the ``+`` operator.

    The first item seeds the accumulation (there is no implicit start
    value), so the result has whatever type ``+`` produces for the items.
    An empty iterable makes ``reduce`` raise ``TypeError``.
    """
    def _add(total, item):
        return total + item

    return reduce(_add, iterable)


def my_lincomb_simple(terms, factors):
    """Return the sum of ``term*factor`` over the paired inputs.

    `terms` and `factors` are walked in lockstep with ``zip``, so surplus
    entries of the longer one are ignored. Every product is formed, even
    when the factor is 1.
    """
    scaled = (term*factor for term, factor in zip(terms, factors))
    return my_sum(scaled)


def my_lincomb(terms, factors):
    """Return the sum of ``term*factor`` over the paired inputs.

    Same result as a plain linear combination, except that a term whose
    factor compares equal to ``1.`` is passed through unchanged instead of
    being multiplied.
    """
    scaled = (term if factor == 1. else term*factor
              for term, factor in zip(terms, factors))
    return my_sum(scaled)


def my_product(iterable):
    """Fold the items of `iterable` together with the ``*`` operator.

    The first item seeds the accumulation (there is no implicit start
    value). An empty iterable makes ``reduce`` raise ``TypeError``.
    """
    def _mul(total, item):
        return total * item

    return reduce(_mul, iterable)


def get_slice_list(shape, axes):
    """
    Helper function which generates index tuples to traverse over all
    combinations of axes, other than the selected axes.

    Parameters
    ----------
    shape: tuple
        Shape of the data array to traverse over.
    axes: tuple
        Axes which should not be iterated over. Negative entries count
        from the last axis. If `axes` is None or empty, a single full
        slice is yielded instead of iterating.

    Yields
    ------
    tuple
        The next tuple of indices and/or slice objects for each dimension.
        A tuple (rather than a list) is yielded in every case so that the
        result can be used directly as an array index.

    Raises
    ------
    ValueError
        If shape is None.
        If axes(axis) does not match shape.
    """
    if shape is None:
        raise ValueError("shape cannot be None.")

    if not axes:
        yield (slice(None, None),)
        return

    ndim = len(shape)
    if not all(-ndim <= axis < ndim for axis in axes):
        raise ValueError("axes(axis) does not match shape.")

    # Normalise negative axes and use a set for O(1) membership tests.
    selected = {axis % ndim for axis in axes}
    # One range per axis that is *not* selected; these are iterated over.
    free_ranges = [range(length) for ax, length in enumerate(shape)
                   if ax not in selected]
    full = slice(None, None)
    for index in product(*free_ranges):
        remaining = iter(index)
        # Selected axes get a full slice, all others the next free index.
        yield tuple(full if ax in selected else next(remaining)
                    for ax in range(ndim))


def safe_cast(tfunc, val):
    """Convert `val` with `tfunc`, refusing conversions that alter the value.

    Parameters
    ----------
    tfunc : callable
        Conversion function, e.g. ``int`` or ``float``.
    val : object
        Value to convert.

    Returns
    -------
    object
        ``tfunc(val)``.

    Raises
    ------
    ValueError
        If the converted value does not compare equal to `val`
        (e.g. ``safe_cast(int, 1.5)``).
    """
    tmp = tfunc(val)
    if val != tmp:
        raise ValueError("value changed during cast")
    return tmp


def parse_spaces(spaces, nspc):
    """Normalise a specification of space indices into a tuple of ints.

    Parameters
    ----------
    spaces : None, int or iterable of int
        ``None`` selects all spaces, a scalar selects a single space,
        an iterable selects the listed spaces in the given order.
    nspc : int
        Total number of spaces; valid indices are ``0 .. nspc-1``.

    Returns
    -------
    tuple of int
        The selected space indices.

    Raises
    ------
    ValueError
        If an entry (or `nspc`) is not an integral value, if an index is
        outside ``[0, nspc)``, or if an index occurs more than once.
    """
    nspc = safe_cast(int, nspc)
    if spaces is None:
        return tuple(range(nspc))
    if np.isscalar(spaces):
        spaces = (safe_cast(int, spaces),)
    else:
        spaces = tuple(safe_cast(int, item) for item in spaces)
    if len(spaces) == 0:
        return spaces
    # Use min/max for the range check: a set has no defined order, so its
    # first and last elements are not the smallest and largest entries.
    if min(spaces) < 0 or max(spaces) >= nspc:
        raise ValueError("space index out of range")
    if len(set(spaces)) != len(spaces):
        raise ValueError("multiply defined space indices")
    return spaces


def infer_space(domain, space):
    """Resolve and validate a single space index for `domain`.

    Parameters
    ----------
    domain : sized container
        Only ``len(domain)`` is used here.
    space : int or None
        Requested index. ``None`` is accepted only when `domain` has exactly
        one entry and then resolves to 0.

    Returns
    -------
    int
        The validated index, in ``[0, len(domain))``.

    Raises
    ------
    ValueError
        If `space` is None while `domain` does not have exactly one entry,
        or if the index is out of range.
    """
    if space is None:
        if len(domain) != 1:
            raise ValueError("'space' index must be given for objects based on"
                             " DomainTuples containing more than one domain")
        space = 0
    space = int(space)
    if space < 0 or space >= len(domain):
        raise ValueError("space index out of range")
    return space


Martin Reinecke's avatar
Martin Reinecke committed
129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
def memo(f):
    """Decorator that caches the result of a method taking only ``self``.

    The value is stored per instance in the dictionary ``self._cache``
    (created on first use) under the method's name, so `f` runs at most
    once per instance. ``functools.wraps`` keeps the name and docstring of
    `f` on the returned wrapper.

    Parameters
    ----------
    f : callable
        Method with the signature ``f(self)``.

    Returns
    -------
    callable
        Wrapper with the same signature that returns the cached value.
    """
    name = f.__name__

    @wraps(f)
    def wrapped_f(self):
        if not hasattr(self, "_cache"):
            self._cache = {}
        try:
            return self._cache[name]
        except KeyError:
            self._cache[name] = f(self)
            return self._cache[name]
    return wrapped_f


class _DocStringInheritor(type):
    """
    Metaclass that lets a class and its members inherit missing docstrings
    from the classes in the MRO of its bases.

    A variation on
    https://groups.google.com/group/comp.lang.python/msg/26f7b4fcb4d66c95
    by Paul McGuire
    """
    def __new__(meta, name, bases, clsdict):
        # Class docstring: take the first non-empty one found in the bases.
        if not('__doc__' in clsdict and clsdict['__doc__']):
            for mro_cls in (mro_cls for base in bases
                            for mro_cls in base.mro()):
                doc = mro_cls.__doc__
                if doc:
                    clsdict['__doc__'] = doc
                    break
        # Member docstrings. Iterate over a snapshot because properties are
        # replaced inside the loop.
        for attr, attribute in list(clsdict.items()):
            if attribute.__doc__:
                continue
            for mro_cls in (mro_cls for base in bases
                            for mro_cls in base.mro()
                            if hasattr(mro_cls, attr)):
                doc = getattr(getattr(mro_cls, attr), '__doc__')
                if not doc:
                    continue
                if isinstance(attribute, property):
                    # A property's docstring is set by rebuilding it.
                    clsdict[attr] = property(attribute.fget,
                                             attribute.fset,
                                             attribute.fdel,
                                             doc)
                else:
                    try:
                        attribute.__doc__ = doc
                    except (AttributeError, TypeError):
                        # Docstring inheritance is best effort: some
                        # objects (e.g. None) have a read-only __doc__,
                        # which must not abort class creation.
                        pass
                break
        return super().__new__(meta, name, bases, clsdict)


class NiftyMeta(_DocStringInheritor):
    # Public metaclass name; adds nothing to _DocStringInheritor.
    pass


class frozendict(collections.abc.Mapping):
    """
    An immutable wrapper around dictionaries that implements the complete
    :py:class:`collections.abc.Mapping` interface. It can be used as a
    drop-in replacement for dictionaries where immutability is desired.

    Instances are hashable as long as all keys and values are hashable.
    """

    # Type of the wrapped mapping; subclasses may override it.
    dict_cls = dict

    def __init__(self, *args, **kwargs):
        # Arguments are forwarded unchanged to `dict_cls`.
        self._dict = self.dict_cls(*args, **kwargs)
        # Hash is computed lazily on first use, see __hash__.
        self._hash = None

    def __getitem__(self, key):
        return self._dict[key]

    def __contains__(self, key):
        return key in self._dict

    def copy(self, **add_or_replace):
        """Return a new instance with the given keys added or replaced."""
        return self.__class__(self, **add_or_replace)

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        return len(self._dict)

    def __repr__(self):
        return '<{} {}>'.format(self.__class__.__name__, self._dict)

    def __hash__(self):
        if self._hash is None:
            # XOR of the item hashes: independent of iteration order, so
            # equal mappings hash equally.
            h = 0
            for key, value in self._dict.items():
                h ^= hash((key, value))
            self._hash = h
        return self._hash


def special_add_at(a, axis, index, b):
    """
    Add the entries of `b` along `axis` onto `a` at the positions `index`.

    For every 1-D line of `b` along `axis`, entry ``j`` is added to position
    ``index[j]`` of the matching line of `a`; repeated indices accumulate
    (the sums are formed with ``np.bincount``).

    Parameters
    ----------
    a : numpy.ndarray
        Array receiving the sums.
    axis : int
        Axis along which to accumulate.
        NOTE(review): the shape slicing below does not normalise negative
        values, so this is presumably expected to be non-negative.
    index : 1-D array of int
        Target position in `a` for each entry of `b` along `axis`; passed
        to ``np.bincount``.
    b : numpy.ndarray
        Values to add; must have the same dtype as `a`.

    Returns
    -------
    numpy.ndarray
        The updated data in the shape of `a`. The update goes through
        ``a.reshape(...)``, so `a` itself is modified when that reshape is
        a view; callers should rely on the return value.

    Raises
    ------
    TypeError
        If `a` and `b` have different dtypes.
    """
    if a.dtype != b.dtype:
        raise TypeError("data type mismatch")
    # Collapse to 3 dimensions: (before axis, axis, after axis).
    sz1 = int(np.prod(a.shape[:axis]))
    sz3 = int(np.prod(a.shape[axis+1:]))
    a2 = a.reshape([sz1, -1, sz3])
    b2 = b.reshape([sz1, -1, sz3])
    is_complex = iscomplextype(a.dtype)
    if is_complex:
        # bincount weights are real-valued: view complex data as pairs of
        # reals, which doubles the length of the last axis.
        dt2 = a.real.dtype
        a2 = a2.view(dt2)
        b2 = b2.view(dt2)
        sz3 *= 2
    for i1 in range(sz1):
        for i3 in range(sz3):
            a2[i1, :, i3] += np.bincount(index, b2[i1, :, i3],
                                         minlength=a2.shape[1])

    if is_complex:
        a2 = a2.view(a.dtype)
    return a2.reshape(a.shape)


# Scalar types that iscomplextype() recognises as complex.
_iscomplex_tpl = (np.complex64, np.complex128)


def iscomplextype(dtype):
    """Return True if `dtype` is a single or double precision complex type.

    Parameters
    ----------
    dtype : numpy.dtype
        Data type to inspect; its ``.type`` attribute is compared against
        ``np.complex64`` and ``np.complex128``.

    Returns
    -------
    bool
    """
    return dtype.type in _iscomplex_tpl


def indent(inp):
    """Indent every line of `inp` by two spaces.

    Trailing whitespace is stripped from each resulting line, so empty
    lines stay empty, and the lines are re-joined with ``"\\n"`` (a final
    newline in `inp` is not reproduced).
    """
    shifted = []
    for line in inp.splitlines():
        shifted.append(("  " + line).rstrip())
    return "\n".join(shifted)