# nifty_utilities.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19 20
from builtins import next
from builtins import range
Theo Steininger's avatar
Theo Steininger committed
21
import numpy as np
22
from itertools import product
Martin Reinecke's avatar
Martin Reinecke committed
23
import itertools
Martin Reinecke's avatar
Martin Reinecke committed
24
from functools import reduce
25

26 27
def get_slice_list(shape, axes):
    """
    Helper function which generates slice list(s) to traverse over all
    combinations of axes, other than the selected axes.

    Parameters
    ----------
    shape: tuple
        Shape of the data array to traverse over.
    axes: tuple
        Axes which should not be iterated over. Negative entries are
        counted from the last axis.

    Yields
    -------
    list
        The next list of indices and/or slice objects for each dimension.
        If `axes` is empty or None, a single one-element list holding a
        full slice is yielded instead.

    Raises
    ------
    ValueError
        If shape is empty.
    ValueError
        If axes(axis) does not match shape.
    """
    if not shape:
        raise ValueError("shape cannot be None.")

    if not axes:
        yield [slice(None, None)]
        return

    ndim = len(shape)
    if not all(-ndim <= axis < ndim for axis in axes):
        raise ValueError("axes(axis) does not match shape.")
    # Map negative axes onto their non-negative counterparts; otherwise they
    # would pass the range check but never match a dimension number below.
    kept = set(axis % ndim for axis in axes)

    iterated_ranges = [range(extent)
                       for dim, extent in enumerate(shape)
                       if dim not in kept]
    for index in product(*iterated_ranges):
        remaining = iter(index)
        # kept axes get a full slice, all others the next running index
        yield [slice(None, None) if dim in kept else next(remaining)
               for dim in range(ndim)]
Theo Steininger's avatar
Theo Steininger committed
70

Theo Steininger's avatar
Theo Steininger committed
71

72
def cast_axis_to_tuple(axis, length=None):
    """
    Convert an axis specification into a tuple of ints.

    Parameters
    ----------
    axis: None, scalar or iterable
        The axis or axes to convert. Every entry is passed through int().
    length: int, optional
        Number of dimensions the axes refer to. If given, negative entries
        are shifted by `length` and every entry is checked to lie in
        [0, length).

    Returns
    -------
    tuple or None
        None if `axis` is None, otherwise the tuple of ints. Duplicate
        entries are preserved.

    Raises
    ------
    TypeError
        If `axis` is neither iterable nor a scalar.
    AssertionError
        If `length` is given and an entry lies outside of [0, length).
    """
    if axis is None:
        return None

    try:
        axis = tuple(int(item) for item in axis)
    except TypeError:
        if not np.isscalar(axis):
            raise TypeError(
                "Could not convert axis-input to tuple of ints")
        axis = (int(axis),)

    if length is not None:
        # shift negative indices to positive ones
        axis = tuple(item if item >= 0 else item + length for item in axis)

        # Duplicate entries are deliberately not removed, in order to allow
        # for the ComposedOperator.

        # Every entry must lie in [0, length). This is raised explicitly so
        # that the check survives `python -O`; the exception type is the one
        # the former assert statement produced.
        for elem in axis:
            if not 0 <= elem < length:
                raise AssertionError(
                    "axis entry %d is out of bounds for length %d"
                    % (elem, length))

    return axis


99
def parse_domain(domain):
    """
    Normalise a domain specification to a tuple of DomainObject instances.

    Parameters
    ----------
    domain: None, DomainObject or iterable
        None is turned into the empty tuple, a single DomainObject is
        wrapped into a one-element tuple, and anything else that is not
        already a tuple is passed to tuple().

    Returns
    -------
    tuple
        The normalised domain.

    Raises
    ------
    TypeError
        If an entry of the resulting tuple is not a DomainObject.
    """
    # NOTE(review): function-scope import, presumably to avoid a circular
    # import at module load time -- confirm before moving it to the top.
    from .domain_object import DomainObject

    if domain is None:
        domain = ()
    elif isinstance(domain, DomainObject):
        domain = (domain,)
    elif not isinstance(domain, tuple):
        domain = tuple(domain)

    if not all(isinstance(entry, DomainObject) for entry in domain):
        raise TypeError(
            "Given object contains something that is not an "
            "instance of DomainObject-class.")
    return domain
Martin Reinecke's avatar
Martin Reinecke committed
114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228
def slicing_generator(shape, axes):
    """
    Helper function which generates slice list(s) to traverse over all
    combinations of axes, other than the selected axes.

    Parameters
    ----------
    shape: tuple
        Shape of the data array to traverse over.
    axes: tuple
        Axes which should not be iterated over. Negative entries are
        counted from the last axis.

    Yields
    -------
    list
        The next list of indices and/or slice objects for each dimension.
        If `axes` is empty or None, a single one-element list holding a
        full slice is yielded instead.

    Raises
    ------
    ValueError
        If shape is empty.
    ValueError
        If axes(axis) does not match shape.
    """
    if not shape:
        raise ValueError("ERROR: shape cannot be None.")

    if not axes:
        yield [slice(None, None)]
        return

    ndim = len(shape)
    if not all(-ndim <= axis < ndim for axis in axes):
        raise ValueError("ERROR: axes(axis) does not match shape.")
    # Negative axes used to pass the range check but never matched a
    # dimension number, so they were silently ignored; normalise them first.
    untouched = set(axis % ndim for axis in axes)

    running_ranges = [range(extent)
                      for dim, extent in enumerate(shape)
                      if dim not in untouched]
    for current_index in itertools.product(*running_ranges):
        pending = iter(current_index)
        # untouched axes get a full slice, all others the next running index
        yield [slice(None, None) if dim in untouched else next(pending)
               for dim in range(ndim)]

def bincount_axis(obj, minlength=None, weights=None, axis=None):
    """
    Apply numpy.bincount along selected axes of an array.

    The axes listed in `axis` are moved to the end and flattened into one
    dimension; numpy.bincount is then evaluated on that dimension for every
    index combination of the remaining axes.

    Parameters
    ----------
    obj: numpy.ndarray
        Array holding the values to count; it must be acceptable input for
        numpy.bincount.
    minlength: int, optional
        Minimum number of bins. The number of bins is the larger one of
        `minlength` and the maximum of `obj` plus one.
    weights: numpy.ndarray, optional
        Weights, reordered and reshaped exactly like `obj`.
    axis: int or tuple of ints, optional
        Axes along which to count. If None, the whole flattened array is
        counted at once.

    Returns
    -------
    numpy.ndarray
        The counts, of dtype int without weights and float with weights.
        With `axis` given, the bin dimension is placed at the position of
        the smallest entry of `axis`.

    Raises
    ------
    ValueError
        If `obj` is zero-dimensional.
    """
    if obj.shape == ():
        raise ValueError("object of too small depth for desired array")

    length = np.amax(obj) + 1
    if minlength is not None:
        length = max(length, minlength)

    data = obj

    # if present, parse the axis keyword and transpose/reorder the data
    # such that all affected axes follow each other. Only if they are in a
    # sequence flattening will be possible
    if axis is not None:
        ndim = len(obj.shape)
        axis = sorted(cast_axis_to_tuple(axis, length=ndim))
        reordering = [x for x in range(ndim) if x not in axis] + axis

        data = np.transpose(data, reordering)
        if weights is not None:
            weights = np.transpose(weights, reordering)

        # semi-flatten the dimensions in `axis`, i.e. after reordering
        # the last ones.
        n_kept = ndim - len(axis)
        semi_flat_dim = reduce(lambda x, y: x*y, data.shape[n_kept:])
        flat_shape = data.shape[:n_kept] + (semi_flat_dim, )
    else:
        flat_shape = (reduce(lambda x, y: x*y, data.shape), )

    data = np.ascontiguousarray(data.reshape(flat_shape))
    if weights is not None:
        weights = np.ascontiguousarray(weights.reshape(flat_shape))

    # The builtin types are what the removed aliases np.int / np.float
    # referred to.
    result_dtype = int if weights is None else float
    local_counts = np.empty(flat_shape[:-1] + (length, ),
                            dtype=result_dtype)

    # iterate over all entries in the surviving axes and compute the local
    # bincounts; indexing with a tuple that leaves out the last dimension
    # selects the complete flattened axis.
    for index in np.ndindex(*flat_shape[:-1]):
        current_weights = None if weights is None else weights[index]
        local_counts[index] = np.bincount(data[index],
                                          weights=current_weights,
                                          minlength=length)

    # restore the original ordering
    # place the bincount stuff at the location of the first `axis` entry
    if axis is not None:
        # axis has been sorted above
        insert_position = axis[0]
        new_ndim = len(local_counts.shape)
        return_order = (list(range(0, insert_position)) +
                        [new_ndim-1, ] +
                        list(range(insert_position, new_ndim-1)))
        local_counts = np.ascontiguousarray(
                            local_counts.transpose(return_order))
    return local_counts