# nifty_utilities.py
# -*- coding: utf-8 -*-

import numpy as np
4
5
from itertools import product

6
from nifty.config import about
7

8

9
10
def get_slice_list(shape, axes):
    """
    Helper function which generates slice list(s) to traverse over all
    combinations of axes, other than the selected axes.

    Parameters
    ----------
    shape: tuple
        Shape of the data array to traverse over.
    axes: tuple
        Axes which should not be iterated over.

    Yields
    ------
    list
        The next list of indices and/or slice objects for each dimension.
        If `axes` is empty or None, a single ``[slice(None, None)]`` is
        yielded instead.

    Raises
    ------
    ValueError
        If shape is empty.
    ValueError
        If axes(axis) does not match shape.
    """
    if not shape:
        raise ValueError(about._errors.cstring("ERROR: shape cannot be None."))

    if not axes:
        # Nothing is held fixed: a single full slice selects everything.
        yield [slice(None, None)]
        return

    # NOTE(review): negative axes pass this bounds check but never match
    # `dim not in axes` below, so they are effectively ignored -- confirm
    # callers normalise axes beforehand.
    if not all(axis < len(shape) for axis in axes):
        # Keep the literal on one line: a backslash continuation inside the
        # string would embed the next line's indentation in the message.
        raise ValueError(
            about._errors.cstring("ERROR: axes(axis) does not match shape.")
        )

    # True for every dimension that is iterated over (i.e. not in `axes`).
    iterated = [dim not in axes for dim in range(len(shape))]
    ranges = [range(length)
              for length, is_iterated in zip(shape, iterated) if is_iterated]
    for index in product(*ranges):
        positions = iter(index)
        yield [next(positions) if is_iterated else slice(None, None)
               for is_iterated in iterated]


def hermitianize_gaussian(x, axes=None):
    """
    Hermitianize Gaussian random data while keeping its variance.

    `x` is averaged with its conjugated point inversion (computed by
    `_hermitianize_inverter`). The sum is scaled by sqrt(0.5) rather than
    0.5 to correct the variance. Entries that the inversion maps onto
    themselves must not be averaged, so they are scaled by sqrt(0.5) once
    more.

    Parameters
    ----------
    x : array-like
        Data to hermitianize. Must provide `shape`, `copy`, `conjugate`,
        elementwise arithmetic, comparison with `.all()` and item assignment.
    axes : iterable of int, optional
        Axes along which the point inversion is performed. Defaults to all
        axes.

    Returns
    -------
    array-like
        `x` itself if it already equals its conjugated inversion, otherwise
        a new hermitianized object (with `hermitian = True` set when the
        object accepts that attribute).
    """
    if axes is not None:
        # Materialise once: the inversion and the fixed-point correction
        # below both iterate over the axes.
        axes = tuple(axes)

    # make the point inversions
    flipped_x = _hermitianize_inverter(x, axes=axes).conjugate()
    # check if x was already hermitian
    if (x == flipped_x).all():
        return x

    # average x and flipped_x.
    # Correct the variance by multiplying sqrt(0.5)
    scale = np.sqrt(0.5)
    x = (x + flipped_x) * scale

    dimensions = len(x.shape)
    if axes is None:
        flip_axes = set(range(dimensions))
    else:
        # the inverter indexes a list with the axis, so negative axes work
        # there; normalise them here as well
        flip_axes = set(axis % dimensions for axis in axes)

    # Fixed points of the (partial) point inversion, per axis:
    # - along a flipped axis of length n the map k -> (n - k) % n fixes
    #   k = 0 and, for even n only, k = n // 2 (for odd n the middle index
    #   is NOT fixed, and for n == 1 both candidates coincide);
    # - along an axis that is not flipped every index is fixed.
    fixed_points = []
    for axis, length in enumerate(x.shape):
        if axis not in flip_axes:
            fixed_points.append((slice(None),))
        elif length % 2 == 0:
            fixed_points.append((0, length // 2))
        else:
            fixed_points.append((0,))

    for index in product(*fixed_points):
        x[index] *= scale

    try:
        x.hermitian = True
    except AttributeError:
        pass

    return x


def hermitianize(x, axes=None):
    """
    Return the Hermitian part of `x`.

    `x` is averaged with the complex conjugate of its point inversion along
    `axes` (all axes when None). If `x` already equals that conjugated
    inversion, `x` itself is returned unchanged. Otherwise the result gets
    `hermitian = True` when the object accepts that attribute.
    """
    mirrored = _hermitianize_inverter(x, axes=axes).conjugate()

    # already hermitian: nothing to average
    if (x == mirrored).all():
        return x

    symmetrized = (x + mirrored) / 2.
    try:
        symmetrized.hermitian = True
    except AttributeError:
        # objects without a settable `hermitian` attribute are returned as-is
        pass
    return symmetrized


def _hermitianize_inverter(x, axes):
Ultima's avatar
Ultima committed
109
    # calculate the number of dimensions the input array has
Ultima's avatar
Ultima committed
110
    dimensions = len(x.shape)
Ultima's avatar
Ultima committed
111
112
113
    # prepare the slicing object which will be used for mirroring
    slice_primitive = [slice(None), ]*dimensions
    # copy the input data
Ultima's avatar
Ultima committed
114
    y = x.copy()
115
116
117
118
119
120

    if axes is None:
        axes = xrange(dimensions)

    # flip in the desired directions
    for i in axes:
Ultima's avatar
Ultima committed
121
        slice_picker = slice_primitive[:]
Ultima's avatar
Ultima committed
122
        slice_picker[i] = slice(1, None, None)
ultimanet's avatar
ultimanet committed
123
124
        slice_picker = tuple(slice_picker)

Ultima's avatar
Ultima committed
125
126
        slice_inverter = slice_primitive[:]
        slice_inverter[i] = slice(None, 0, -1)
ultimanet's avatar
ultimanet committed
127
        slice_inverter = tuple(slice_inverter)
Ultima's avatar
Ultima committed
128
129

        try:
130
131
            y.set_data(to_key=slice_picker, data=y,
                       from_key=slice_inverter)
Ultima's avatar
Ultima committed
132
133
134
        except(AttributeError):
            y[slice_picker] = y[slice_inverter]
    return y
135
136


137
def direct_vdot(x, y):
    """
    Compute the vdot of `x` and `y`.

    Field-like inputs are first unwrapped via their `get_val()` method.
    Then `x.vdot(y)` is tried, then `y.vdot(x)`; if neither object has a
    `vdot` method, the result falls back to `numpy.vdot(x, y)`.
    """
    def _unwrap(candidate):
        # the input could be a field: try to extract its data
        try:
            return candidate.get_val()
        except AttributeError:
            return candidate

    x = _unwrap(x)
    y = _unwrap(y)

    # NOTE(review): if `vdot` conjugates its receiver (as numpy.vdot does its
    # first argument), the `y.vdot(x)` attempt returns the complex conjugate
    # of the other paths -- confirm against the data-object implementation.
    for left, right in ((x, y), (y, x)):
        try:
            return left.vdot(right)
        except AttributeError:
            continue

    # fallback to numpy
    return np.vdot(x, y)


def convert_nested_list_to_object_array(x):
    """
    Convert a (possibly nested) list/tuple into a numpy object array.

    Only the list/tuple nesting levels become dimensions of the result. The
    innermost items (e.g. ndarrays) are stored as objects, even when they
    all share one shape.

    Parameters
    ----------
    x : object
        A nested list/tuple, or any other object.

    Returns
    -------
    numpy.ndarray or object
        An object array shaped like the list nesting, or `x` itself if `x`
        is not a list/tuple. An empty list yields an empty object array.
    """
    # if x is a nested_list full of ndarrays all having the same size,
    # np.shape returns the shape of the ndarrays, too, i.e. too many
    # dimensions.
    # NOTE(review): with NumPy >= 1.24, np.shape raises ValueError for
    # ragged nestings -- confirm whether such inputs occur.
    possible_shape = np.shape(x)

    # Count how many leading dimensions come from list/tuple nesting.
    dimension_counter = 0
    current_extract = x
    for _ in range(len(possible_shape)):
        if not isinstance(current_extract, (list, tuple)):
            break
        dimension_counter += 1
        if len(current_extract) == 0:
            # empty level: there is no first element to descend into
            break
        current_extract = current_extract[0]
    real_shape = possible_shape[:dimension_counter]

    # if the numpy array was not encapsulated at all, return x directly
    if real_shape == ():
        return x

    # Prepare the carrier-object (`object`, since `np.object` was removed
    # in NumPy 1.24)
    carrier = np.empty(real_shape, dtype=object)
    for index in np.ndindex(real_shape):
        try:
            carrier[index] = x[index]
        except TypeError:
            # plain lists cannot be indexed by a tuple: descend level-wise
            extracted = x
            for position in index:
                extracted = extracted[position]
            carrier[index] = extracted
    return carrier


def field_map(ishape, function, *args):
    """
    Apply `function` at every index of an index space.

    Parameters
    ----------
    ishape : tuple
        Shape of the index space. For ``()`` the function is called once,
        directly on `args`.
    function : callable
        Called without arguments when `args` is empty, otherwise with the
        entry of every array in `args` at the current index.
    *args : array-like
        Arrays to map over. The result has the shape and dtype of
        ``np.empty_like(args[0])``. Indices are clipped to each array's
        shape, which lets arrays with length-1 axes broadcast, e.g. shapes
        (3, 4, 3) and (3, 4, 1).

    Returns
    -------
    object or numpy.ndarray
        ``function(*args)`` if ``ishape == ()``. Otherwise an array holding
        the function value at every index (an object array when `args` is
        empty).
    """
    if ishape == ():
        return function(*args)

    if args == ():
        # `object` instead of `np.object`, which was removed in NumPy 1.24
        result = np.empty(ishape, dtype=object)
        for index in np.ndindex(result.shape):
            result[index] = function()
        return result

    def get_clipped(w, ind):
        # clip the index to w's shape so that length-1 axes broadcast
        w_shape = np.array(np.shape(w))
        get_tuple = tuple(np.clip(ind, 0, w_shape - 1))
        return w[get_tuple]

    result = np.empty_like(args[0])
    for index in np.ndindex(result.shape):
        result[index] = function(*[get_clipped(arg, index) for arg in args])
    return result


def cast_axis_to_tuple(axis, length):
    """
    Normalise an axis specification to a tuple of non-negative ints.

    Parameters
    ----------
    axis : None, scalar, or iterable
        The axis or axes to normalise; every entry is passed through `int`.
    length : int
        Number of dimensions, added to negative entries.

    Returns
    -------
    tuple of int or None
        None if `axis` is None. Otherwise the axes as ints, with negative
        entries shifted by `length`.

    Raises
    ------
    TypeError
        If `axis` is neither convertible element-wise nor a scalar.
    """
    if axis is None:
        return None

    try:
        as_ints = tuple(map(int, axis))
    except TypeError:
        if not np.isscalar(axis):
            raise TypeError(
                "ERROR: Could not convert axis-input to tuple of ints")
        as_ints = (int(axis), )

    # shift negative indices to positive ones
    return tuple(item + length if item < 0 else item for item in as_ints)


def complex_bincount(x, weights=None, minlength=None):
    """
    Weighted bincount that also supports complex weights.

    For complex `weights`, the real and imaginary parts are binned separately
    and recombined as ``real + 1j * imag``.

    Parameters
    ----------
    x : array-like of non-negative ints
        Bin indices. If `x` has a `bincount` method it is used; otherwise
        `numpy.bincount` is called on `x`.
    weights : array-like, optional
        Weight for every entry of `x`; may be real or complex.
    minlength : int, optional
        Minimum number of bins in the output.

    Returns
    -------
    array-like
        The (weighted) count per bin; complex when `weights` is complex.
    """
    try:
        complex_weights_Q = issubclass(weights.dtype.type,
                                       np.complexfloating)
    except AttributeError:
        complex_weights_Q = False

    def _bincount(w):
        try:
            bincount = x.bincount
        except AttributeError:
            # plain array-like input: numpy.bincount takes an int minlength
            return np.bincount(x, weights=w,
                               minlength=0 if minlength is None else minlength)
        return bincount(weights=w, minlength=minlength)

    if complex_weights_Q:
        # The imaginary part must be re-attached with 1j; simply adding the
        # two real-valued bincounts would mix both parts into the real axis.
        return _bincount(weights.real) + 1j * _bincount(weights.imag)
    return _bincount(weights)