nifty_utilities.py 8.56 KB
Newer Older
Ultima's avatar
Ultima committed
1
2
3
# -*- coding: utf-8 -*-

import numpy as np
4
5
from itertools import product

6
from nifty.config import about
7

8

9
10
def get_slice_list(shape, axes):
    """
    Helper function which generates slice list(s) to traverse over all
    combinations of axes, other than the selected axes.

    Parameters
    ----------
    shape: tuple
        Shape of the data array to traverse over.
    axes: tuple
        Axes which should not be iterated over. Negative entries count
        from the last dimension, i.e. ``-1`` refers to ``len(shape) - 1``.

    Yields
    ------
    list
        The next list of indices and/or slice objects for each dimension:
        ``slice(None, None)`` at every selected axis and an integer index
        at every other position. If `axes` is empty or None, a single
        one-element list ``[slice(None, None)]`` is yielded.

    Raises
    ------
    ValueError
        If shape is empty.
    ValueError
        If axes(axis) does not match shape.

    Notes
    -----
    This is a generator, so the checks (and the errors they raise) only
    run once the first element is requested.
    """

    if not shape:
        raise ValueError(about._errors.cstring("ERROR: shape cannot be None."))

    if axes:
        ndim = len(shape)
        if not all(-ndim <= axis < ndim for axis in axes):
            raise ValueError(
                about._errors.cstring(
                    "ERROR: axes(axis) does not match shape."))

        # Fold negative axes onto their non-negative counterparts. Without
        # this a negative axis passes the range check above but never
        # matches a dimension number below, so it would be ignored.
        kept_axes = set(axis % ndim for axis in axes)

        iterated_ranges = [range(extent)
                           for dim, extent in enumerate(shape)
                           if dim not in kept_axes]
        for index in product(*iterated_ranges):
            # `index` only holds entries for the iterated dimensions;
            # consume them in order while walking over all dimensions.
            remaining = iter(index)
            yield [slice(None, None) if dim in kept_axes else next(remaining)
                   for dim in range(ndim)]
    else:
        yield [slice(None, None)]
Ultima's avatar
Ultima committed
56

Ultima's avatar
Ultima committed
57

58
def hermitianize_gaussian(x, axes=None):
    """
    Hermitianize `x` while rescaling so that the variance is corrected.

    `x` is combined with its conjugated point inversion and scaled by
    ``sqrt(0.5)`` instead of ``0.5``; selected corner/mid-index entries
    are then scaled by ``sqrt(0.5)`` once more.

    Parameters
    ----------
    x : array-like
        Input data. Must provide ``copy``, ``conjugate``, ``shape``,
        comparison, arithmetic and item assignment (e.g. a numpy array).
    axes : iterable of int, optional
        Axes along which the point inversion is performed. None (default)
        selects all axes.

    Returns
    -------
    array-like
        `x` itself if it already equals its conjugated point inversion,
        otherwise a newly computed object.
    """
    # make the point inversions
    flipped_x = _hermitianize_inverter(x, axes=axes)
    flipped_x = flipped_x.conjugate()

    # check if x was already hermitian
    if (x == flipped_x).all():
        return x

    # average x and flipped_x.
    # Correct the variance by multiplying sqrt(0.5)
    x = (x + flipped_x) * np.sqrt(0.5)

    # The fixed points of the point inversion must not be averaged.
    # Hence one must multiply them again with sqrt(0.5)
    # -> Get the middle index of the array
    # (builtin `int` replaces the `np.int` alias, which newer NumPy
    # releases no longer provide)
    mid_index = np.array(x.shape, dtype=int) // 2
    dimensions = mid_index.size

    # Use ndindex to iterate over all combinations of zeros and the
    # mid_index in order to correct all fixed points.
    if axes is None:
        axes = range(dimensions)

    ndlist = tuple(2 if i in axes else 1 for i in range(dimensions))
    # NOTE(review): for an odd-length axis n, index n // 2 is not mapped
    # onto itself by the inversion k -> n - k, and for length-1 axes the
    # index 0 is visited twice -- confirm this is the intended treatment.
    for i in np.ndindex(ndlist):
        temp_index = tuple(i * mid_index)
        x[temp_index] *= np.sqrt(0.5)

    # Flag the result as hermitian where the object supports attribute
    # assignment; plain numpy arrays do not, which is fine.
    try:
        x.hermitian = True
    except AttributeError:
        pass

    return x


91
def hermitianize(x, axes=None):
    """
    Return the mean of `x` and its conjugated point inversion.

    Parameters
    ----------
    x : array-like
        Input data; handed to `_hermitianize_inverter` and added to the
        conjugate of what that returns.
    axes : iterable of int, optional
        Axes along which the point inversion is performed. None (default)
        is passed through to `_hermitianize_inverter`.

    Returns
    -------
    array-like
        ``x + mirrored``, divided in place by ``2.``.
    """
    # conjugated, point-inverted copy of the input
    mirrored = _hermitianize_inverter(x, axes=axes).conjugate()

    # build the sum first and halve it in place afterwards, so that no
    # second temporary of the full size is needed for the division
    averaged = x + mirrored
    averaged /= 2.

    return averaged
Ultima's avatar
Ultima committed
107

Ultima's avatar
Ultima committed
108

109
def _hermitianize_inverter(x, axes):
    """
    Return a copy of `x` that is point-inverted along the given axes.

    Along every selected axis entry 0 stays in place while the entries
    ``1 .. n-1`` are reversed, i.e. index ``k`` receives the value that
    was stored at ``n - k``.

    Parameters
    ----------
    x : array-like
        Input data providing ``shape`` and ``copy``. It is not modified;
        all work happens on the copy.
    axes : iterable of int or None
        Axes to invert. None selects all axes.

    Returns
    -------
    array-like
        The inverted copy of `x`.
    """
    # calculate the number of dimensions the input array has
    dimensions = len(x.shape)
    # prepare the slicing object which will be used for mirroring
    slice_primitive = [slice(None), ] * dimensions
    # copy the input data
    y = x.copy()

    if axes is None:
        axes = range(dimensions)

    # flip in the desired directions
    for i in axes:
        # target: everything but the first entry along axis i
        slice_picker = slice_primitive[:]
        slice_picker[i] = slice(1, None, None)
        slice_picker = tuple(slice_picker)

        # source: the same entries, traversed backwards
        slice_inverter = slice_primitive[:]
        slice_inverter[i] = slice(None, 0, -1)
        slice_inverter = tuple(slice_inverter)

        # Objects offering `set_data` get the keyed copy; anything else
        # (e.g. a numpy array) falls back to plain item assignment.
        try:
            y.set_data(to_key=slice_picker, data=y,
                       from_key=slice_inverter)
        except AttributeError:
            y[slice_picker] = y[slice_inverter]
    return y
136
137


138
def direct_vdot(x, y):
    """
    Compute the vdot of `x` and `y`, unwrapping field-like inputs first.

    Parameters
    ----------
    x, y : object
        Operands. If an operand offers ``get_val()``, the returned value
        is used in its place.

    Returns
    -------
    object
        ``x.vdot(y)`` if available, else ``y.vdot(x)`` if available, else
        ``np.vdot(x, y)``.
    """
    def _payload(operand):
        # the input could be a field: try to extract the data
        try:
            return operand.get_val()
        except AttributeError:
            return operand

    x = _payload(x)
    y = _payload(y)

    # prefer a vdot method of either operand, first x then y
    for left, right in ((x, y), (y, x)):
        try:
            return left.vdot(right)
        except AttributeError:
            continue

    # neither operand brings its own vdot: fall back to numpy
    return np.vdot(x, y)

Ultima's avatar
Ultima committed
162
163

def convert_nested_list_to_object_array(x):
    """
    Turn a nested list/tuple into a numpy object array of its leaves.

    Only the list/tuple nesting levels become array dimensions; whatever
    sits below them (e.g. numpy arrays) is stored as a single object per
    entry.

    Parameters
    ----------
    x : list, tuple or object
        The (possibly nested) container to convert.

    Returns
    -------
    numpy.ndarray or object
        An array with ``dtype=object`` whose shape follows the list/tuple
        nesting of `x`. If `x` is not a list or tuple at all, `x` itself
        is returned unchanged.
    """
    # if x is a nested_list full of ndarrays all having the same size,
    # np.shape returns the shape of the ndarrays, too, i.e. too many
    # dimensions
    possible_shape = np.shape(x)

    # Check if possible_shape goes too deep: follow the first element of
    # every level and count how many levels are real lists/tuples.
    dimension_counter = 0
    current_extract = x
    for _ in possible_shape:
        if not isinstance(current_extract, (list, tuple)):
            break
        dimension_counter += 1
        if len(current_extract) == 0:
            # an empty container has no first element to descend into
            break
        current_extract = current_extract[0]
    real_shape = possible_shape[:dimension_counter]

    # if the numpy array was not encapsulated at all, return x directly
    if real_shape == ():
        return x

    # Prepare the carrier-object. Builtin `object` replaces the
    # `np.object` alias, which newer NumPy releases no longer provide.
    carrier = np.empty(real_shape, dtype=object)
    for ii in np.ndindex(*real_shape):
        try:
            carrier[ii] = x[ii]
        except TypeError:
            # plain lists/tuples reject tuple indices: walk level by level
            extracted = x
            for position in ii:
                extracted = extracted[position]
            carrier[ii] = extracted
    return carrier


def field_map(ishape, function, *args):
    """
    Apply `function` entry-wise over arrays in `args`.

    Parameters
    ----------
    ishape : tuple
        If ``()``, `function` is simply called once with `args`.
        Otherwise, and if no `args` are given, it is the shape of the
        result array.
    function : callable
        Called once per entry; receives one value per element of `args`.
    *args : array-like
        Operands. The first one determines shape (and, via
        ``np.empty_like``, the dtype) of the result. Indices into the
        other operands are clipped to their shape, which allows combining
        e.g. shapes (3, 4, 3) and (3, 4, 1).

    Returns
    -------
    object
        ``function(*args)`` if `ishape` is ``()``; an object array of
        shape `ishape` filled with ``function()`` if `args` is empty;
        otherwise an array shaped like ``args[0]`` holding the
        entry-wise results.
    """
    if ishape == ():
        return function(*args)

    if args == ():
        # builtin `object` replaces the removed `np.object` alias
        result = np.empty(ishape, dtype=object)
        for ii in np.ndindex(result.shape):
            result[ii] = function()
        return result

    # define a helper function in order to clip the get-indices
    # to be suitable for the foreign arrays in args.
    # This allows you to do operations, like adding to fields
    # with ishape (3,4,3) and (3,4,1)
    def get_clipped(w, ind):
        w_shape = np.array(np.shape(w))
        get_tuple = tuple(np.clip(ind, 0, w_shape - 1))
        return w[get_tuple]

    result = np.empty_like(args[0])
    for ii in np.ndindex(result.shape):
        result[ii] = function(*[get_clipped(z, ii) for z in args])
    return result
225
226
227
228
229
230
231
232
233


def cast_axis_to_tuple(axis, length):
    """
    Normalize an axis specification to a tuple of non-negative ints.

    Parameters
    ----------
    axis : None, scalar or iterable
        The axis or axes. Iterables are converted entry-wise with
        ``int``; a scalar becomes a one-element tuple.
    length : int
        Number of dimensions; used to shift negative entries and for the
        bounds check.

    Returns
    -------
    tuple of int or None
        None if `axis` is None, otherwise the duplicate-free axes in
        ascending order, each within ``0 <= entry < length``.

    Raises
    ------
    TypeError
        If `axis` is neither iterable nor a scalar.
    AssertionError
        If an entry lies outside ``[0, length)`` after shifting.
    """
    if axis is None:
        return None
    try:
        axis = tuple(int(item) for item in axis)
    except TypeError:
        if np.isscalar(axis):
            axis = (int(axis),)
        else:
            raise TypeError(
                "ERROR: Could not convert axis-input to tuple of ints")

    # shift negative indices to positive ones
    axis = tuple(item if (item >= 0) else (item + length) for item in axis)

    # remove duplicate entries; sort so that the result does not depend
    # on the iteration order of the intermediate set
    axis = tuple(sorted(set(axis)))

    # Make sure all entries are elements of [0, length). The error is
    # raised explicitly instead of via `assert`, so the check is still
    # performed under `python -O`; the exception type is unchanged.
    for elem in axis:
        if not (0 <= elem < length):
            raise AssertionError(
                "ERROR: axis %d is out of bounds for length %d."
                % (elem, length))

    return axis


def complex_bincount(x, weights=None, minlength=None):
    """
    Bincount of `x` that also supports complex-valued weights.

    Parameters
    ----------
    x : object
        Data object providing a ``bincount(weights=..., minlength=...)``
        method.
    weights : array-like, optional
        Weights handed to ``x.bincount``. If their dtype is a complex
        floating type, real and imaginary part are binned separately.
    minlength : int, optional
        Passed through to ``x.bincount`` unchanged.

    Returns
    -------
    object
        The result of ``x.bincount``. For complex weights: the bincount
        of the real parts plus ``1j`` times the bincount of the
        imaginary parts.
    """
    try:
        complex_weights_Q = issubclass(weights.dtype.type,
                                       np.complexfloating)
    except AttributeError:
        # weights is None or carries no dtype: treat as non-complex
        complex_weights_Q = False

    if complex_weights_Q:
        real_bincount = x.bincount(weights=weights.real,
                                   minlength=minlength)
        imag_bincount = x.bincount(weights=weights.imag,
                                   minlength=minlength)
        # The imaginary sums must re-enter as the imaginary part; simply
        # adding both bincounts would fold them into the real part.
        return real_bincount + imag_bincount * 1j
    else:
        return x.bincount(weights=weights, minlength=minlength)
Jait Dixit's avatar
Jait Dixit committed
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284


def get_default_codomain(domain):
    """
    Return the default codomain for `domain`.

    The space type of `domain` selects a transformation class, whose
    ``get_codomain`` is then asked for the result.

    Parameters
    ----------
    domain : RGSpace, HPSpace, GLSpace or LMSpace
        The space for which a codomain is requested.

    Returns
    -------
    object
        Whatever ``get_codomain(domain)`` of the matching transformation
        class returns.

    Raises
    ------
    TypeError
        If `domain` is an instance of none of the supported spaces.
    """
    # imported at call time rather than at module import time
    from nifty.spaces import RGSpace, HPSpace, GLSpace, LMSpace
    from nifty.operators.fft_operator.transformations import (
        RGRGTransformation, HPLMTransformation, GLLMTransformation,
        LMGLTransformation)

    # checked top to bottom; the first matching space type wins
    # TODO: get the preferred transformation path for LMSpace from config
    dispatch = (
        (RGSpace, RGRGTransformation),
        (HPSpace, HPLMTransformation),
        (GLSpace, GLLMTransformation),
        (LMSpace, LMGLTransformation),
    )
    for space_type, transformation in dispatch:
        if isinstance(domain, space_type):
            return transformation.get_codomain(domain)

    raise TypeError(about._errors.cstring('ERROR: unknown domain'))