nifty_utilities.py 8.44 KB
Newer Older
Ultima's avatar
Ultima committed
1
2
3
# -*- coding: utf-8 -*-

import numpy as np
4
5
from itertools import product

6

7
8
def get_slice_list(shape, axes):
    """
    Helper function which generates slice list(s) to traverse over all
    combinations of axes, other than the selected axes.

    Parameters
    ----------
    shape: tuple
        Shape of the data array to traverse over.
    axes: tuple
        Axes which should not be iterated over. Negative entries count
        from the end, as in numpy.

    Yields
    ------
    list
        The next list of indices and/or slice objects for each dimension.
        Selected axes get ``slice(None, None)``, all others an integer
        index. If `axes` is empty or None, a single ``[slice(None, None)]``
        is yielded.

    Raises
    ------
    ValueError
        If shape is empty.
    ValueError
        If axes(axis) does not match shape.
    """
    if not shape:
        raise ValueError("shape cannot be None.")

    if not axes:
        yield [slice(None, None)]
        return

    ndim = len(shape)
    if not all(-ndim <= axis < ndim for axis in axes):
        raise ValueError("axes(axis) does not match shape.")
    # Normalize negative axes; otherwise the membership tests below would
    # never match them and those axes would be iterated over.
    selected = set(axis % ndim for axis in axes)

    iterables = [range(length) for dim, length in enumerate(shape)
                 if dim not in selected]
    for index in product(*iterables):
        # Hand out the product indices, in order, to the non-selected dims.
        it_iter = iter(index)
        yield [slice(None, None) if dim in selected else next(it_iter)
               for dim in range(ndim)]


def hermitianize_gaussian(x, axes=None):
    """
    Symmetrize `x` under the conjugated point inversion while keeping its
    variance.

    `x` is averaged with the complex conjugate of its point inversion (as
    produced by `_hermitianize_inverter`) and scaled by sqrt(0.5) instead of
    0.5, which corrects the variance of the averaged values. Entries that
    the inversion maps onto themselves are scaled by sqrt(0.5) once more,
    since for them the "average" is just twice their real part.

    Parameters
    ----------
    x : array-like
        Data to symmetrize. Must support ``shape``, ``conjugate()``,
        elementwise ``==``, ``+``, ``*`` and item assignment (numpy arrays
        qualify; other containers presumably too -- TODO confirm).
    axes : iterable of int, optional
        Axes along which the point inversion is performed; all axes if
        None. Negative entries count from the end.

    Returns
    -------
    The symmetrized data. If `x` already equals its conjugated inversion,
    `x` itself is returned untouched.
    """
    # make the point inversions
    flipped_x = _hermitianize_inverter(x, axes=axes).conjugate()

    # check if x was already hermitian
    if (x == flipped_x).all():
        return x

    # average x and flipped_x.
    # Correct the variance by multiplying sqrt(0.5)
    x = (x + flipped_x) * np.sqrt(0.5)

    # The fixed points of the point inversion must not be averaged, hence
    # they are multiplied with sqrt(0.5) once more.
    # Along an inverted axis of length n the inversion maps index k >= 1 to
    # n - k and keeps index 0, so the fixed indices are 0 and -- only for
    # even n -- n // 2. Along an axis that is not inverted, every index is
    # fixed, hence the full slice.
    shape = x.shape
    dimensions = len(shape)
    if axes is None:
        inverted = set(range(dimensions))
    else:
        inverted = set(axis % dimensions for axis in axes)

    fixed_indices = []
    for dim, length in enumerate(shape):
        if dim not in inverted:
            fixed_indices.append((slice(None),))
        elif length % 2 == 0:
            fixed_indices.append((0, length // 2))
        else:
            fixed_indices.append((0,))

    correction = np.sqrt(0.5)
    for index in product(*fixed_indices):
        x[index] *= correction

    # Mark the result as hermitian if the container supports that flag.
    try:
        x.hermitian = True
    except AttributeError:
        pass

    return x


def hermitianize(x, axes=None):
    """
    Return the plain average of `x` and its conjugated point inversion.

    Parameters
    ----------
    x : array-like
        Data to symmetrize; must support ``conjugate()``, ``+`` and ``/=``.
    axes : iterable of int, optional
        Axes along which the point inversion is performed; all axes if None.

    Returns
    -------
    ``(x + conj(inverted x)) / 2``. Unlike `hermitianize_gaussian`, no
    variance correction is applied and fixed points are averaged as well.
    """
    mirrored = _hermitianize_inverter(x, axes=axes).conjugate()
    # The sum is a freshly created object (for numpy arrays at least), so
    # the in-place division below does not touch `x`.
    symmetric = x + mirrored
    symmetric /= 2.
    return symmetric


def _hermitianize_inverter(x, axes):
Ultima's avatar
Ultima committed
105
    # calculate the number of dimensions the input array has
Ultima's avatar
Ultima committed
106
    dimensions = len(x.shape)
Ultima's avatar
Ultima committed
107
    # prepare the slicing object which will be used for mirroring
Jait Dixit's avatar
Jait Dixit committed
108
    slice_primitive = [slice(None), ] * dimensions
Ultima's avatar
Ultima committed
109
    # copy the input data
Ultima's avatar
Ultima committed
110
    y = x.copy()
111
112
113
114
115
116

    if axes is None:
        axes = xrange(dimensions)

    # flip in the desired directions
    for i in axes:
Ultima's avatar
Ultima committed
117
        slice_picker = slice_primitive[:]
Ultima's avatar
Ultima committed
118
        slice_picker[i] = slice(1, None, None)
ultimanet's avatar
ultimanet committed
119
120
        slice_picker = tuple(slice_picker)

Ultima's avatar
Ultima committed
121
122
        slice_inverter = slice_primitive[:]
        slice_inverter[i] = slice(None, 0, -1)
ultimanet's avatar
ultimanet committed
123
        slice_inverter = tuple(slice_inverter)
Ultima's avatar
Ultima committed
124
125

        try:
126
127
            y.set_data(to_key=slice_picker, data=y,
                       from_key=slice_inverter)
Ultima's avatar
Ultima committed
128
129
130
        except(AttributeError):
            y[slice_picker] = y[slice_inverter]
    return y
131
132


133
def direct_vdot(x, y):
    """
    Compute vdot(x, y), i.e. the dot product with `x` complex-conjugated.

    Field-like inputs are unwrapped via ``get_val()`` when available. The
    computation is then delegated to ``x.vdot(y)``, or failing that to
    ``y.vdot(x)`` (conjugated back), and finally to ``np.vdot(x, y)``.

    Parameters
    ----------
    x, y : array-like or objects with ``get_val()`` / ``vdot()``

    Returns
    -------
    The scalar vdot(x, y).
    """
    # the input could be fields. Try to extract the data
    try:
        x = x.get_val()
    except AttributeError:
        pass
    try:
        y = y.get_val()
    except AttributeError:
        pass

    # try to make a direct vdot
    try:
        return x.vdot(y)
    except AttributeError:
        pass

    # y.vdot(x) conjugates y instead of x (assuming `.vdot` follows numpy's
    # convention), i.e. it yields conj(vdot(x, y)) -- undo that.
    try:
        return np.conj(y.vdot(x))
    except AttributeError:
        pass

    # fallback to numpy
    return np.vdot(x, y)


def convert_nested_list_to_object_array(x):
    """
    Convert a nested list/tuple of objects into a numpy object array.

    Only the leading dimensions that stem from list/tuple nesting become
    dimensions of the result; whatever sits at the innermost level (e.g.
    equally shaped ndarrays) is stored as a single object per entry.

    Parameters
    ----------
    x : list, tuple or other object
        The (possibly nested) container.

    Returns
    -------
    numpy.ndarray of dtype object, or `x` itself if it is not a list/tuple.
    """
    # if x is a nested_list full of ndarrays all having the same size,
    # np.shape returns the shape of the ndarrays, too, i.e. too many
    # dimensions
    possible_shape = np.shape(x)

    # Count how many leading levels really are lists/tuples.
    dimension_counter = 0
    current_extract = x
    for _ in range(len(possible_shape)):
        if not isinstance(current_extract, (list, tuple)):
            break
        dimension_counter += 1
        # An empty level has nothing to descend into (and indexing [0]
        # would fail); its zero length is already part of the shape.
        if len(current_extract) == 0:
            break
        current_extract = current_extract[0]
    real_shape = possible_shape[:dimension_counter]

    # if the numpy array was not encapsulated at all, return x directly
    if real_shape == ():
        return x

    # Prepare the carrier-object
    carrier = np.empty(real_shape, dtype=object)
    for ii in np.ndindex(real_shape):
        try:
            carrier[ii] = x[ii]
        except TypeError:
            # Plain lists/tuples cannot be indexed with a tuple: walk down
            # one nesting level at a time.
            extracted = x
            for j in ii:
                extracted = extracted[j]
            carrier[ii] = extracted
    return carrier


def field_map(ishape, function, *args):
    """
    Apply `function` entry-wise over an index space of shape `ishape`.

    Parameters
    ----------
    ishape : tuple
        Shape of the index space to map over. If it is ``()``,
        ``function(*args)`` is returned directly.
    function : callable
        Called once per index.
    *args : array-like
        Per-index arguments for `function`. Indices are clipped to each
        argument's own shape, so arguments of shape e.g. (3, 4, 1) can be
        combined with ones of shape (3, 4, 3).

    Returns
    -------
    The result of `function(*args)` if `ishape` is ``()``. Otherwise, with
    no `args`, an object array of shape `ishape` filled with ``function()``;
    with `args`, an array like ``args[0]`` filled with the per-index
    results.
    """
    if ishape == ():
        return function(*args)

    if args == ():
        result = np.empty(ishape, dtype=object)
        for ii in np.ndindex(ishape):
            result[ii] = function()
        return result

    # define a helper function in order to clip the get-indices
    # to be suitable for the foreign arrays in args.
    # This allows you to do operations, like adding to fields
    # with ishape (3,4,3) and (3,4,1)
    def get_clipped(w, ind):
        w_shape = np.array(np.shape(w))
        get_tuple = tuple(np.clip(ind, 0, w_shape - 1))
        return w[get_tuple]

    result = np.empty_like(args[0])
    for ii in np.ndindex(result.shape):
        result[ii] = function(*[get_clipped(z, ii) for z in args])
    return result


def cast_axis_to_tuple(axis, length):
    """
    Normalize an axis specification to a tuple of non-negative ints.

    Parameters
    ----------
    axis : None, int-like scalar or iterable of int-likes
        The axis/axes to normalize. Negative entries count from the end.
    length : int
        Number of available axes.

    Returns
    -------
    tuple of int, or None if `axis` is None. Duplicates are preserved.

    Raises
    ------
    TypeError
        If `axis` is neither iterable nor a scalar.
    AssertionError
        If an entry lies outside [0, length) after shifting negative ones.
    """
    if axis is None:
        return None
    try:
        axis = tuple(int(item) for item in axis)
    except TypeError:
        if np.isscalar(axis):
            axis = (int(axis),)
        else:
            raise TypeError(
                "Could not convert axis-input to tuple of ints")

    # shift negative indices to positive ones
    axis = tuple(item if (item >= 0) else (item + length) for item in axis)

    # Duplicate entries are deliberately NOT removed, in order to allow for
    # the ComposedOperator.

    # Make sure that all entries are elements of [0, length). This is an
    # explicit raise rather than an `assert` statement so the check is not
    # stripped under `python -O`; AssertionError is kept for compatibility.
    for elem in axis:
        if not (0 <= elem < length):
            raise AssertionError(
                "axis %r out of range for length %r" % (elem, length))

    return axis


def complex_bincount(x, weights=None, minlength=None):
    """
    Weighted bincount of `x` that also supports complex weights.

    `x` must provide a ``bincount(weights=..., minlength=...)`` method. For
    complex weights, the real and imaginary parts are binned separately and
    recombined into a complex result.

    Parameters
    ----------
    x : object with a ``bincount`` method
    weights : array-like, optional
        Per-entry weights; complex dtypes are detected via ``dtype``.
    minlength : int, optional
        Passed through to ``x.bincount``.

    Returns
    -------
    The (possibly complex) bincount.
    """
    try:
        complex_weights_Q = issubclass(weights.dtype.type,
                                       np.complexfloating)
    except AttributeError:
        complex_weights_Q = False

    if complex_weights_Q:
        real_bincount = x.bincount(weights=weights.real,
                                   minlength=minlength)
        imag_bincount = x.bincount(weights=weights.imag,
                                   minlength=minlength)
        # The imaginary part must be re-attached as such; plain addition
        # would mix both parts into one real number.
        return real_bincount + 1j * imag_bincount
    else:
        return x.bincount(weights=weights, minlength=minlength)


def get_default_codomain(domain):
    """
    Return the default harmonic codomain of `domain`.

    The codomain is obtained from the transformation class registered for
    the domain's space type below; the first matching type wins.

    Raises
    ------
    TypeError
        If `domain` is none of the supported space types.
    """
    from nifty.spaces import RGSpace, HPSpace, GLSpace, LMSpace
    from nifty.operators.fft_operator.transformations import (
        RGRGTransformation, HPLMTransformation, GLLMTransformation,
        LMGLTransformation)

    # Ordered (space type, transformation) pairs, checked with isinstance.
    # TODO: get the preferred transformation path for LMSpace from config
    # (currently always LM -> GL).
    dispatch = ((RGSpace, RGRGTransformation),
                (HPSpace, HPLMTransformation),
                (GLSpace, GLLMTransformation),
                (LMSpace, LMGLTransformation))

    for space_type, transformation in dispatch:
        if isinstance(domain, space_type):
            return transformation.get_codomain(domain)

    raise TypeError('ERROR: unknown domain')