# NIFTy
# Copyright (C) 2017  Theo Steininger
#
# Author: Theo Steininger
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import numpy as np
from itertools import product

def get_slice_list(shape, axes):
    """
    Helper function which generates slice list(s) to traverse over all
    combinations of axes, other than the selected axes.

    Parameters
    ----------
    shape: tuple
        Shape of the data array to traverse over.
    axes: tuple
        Axes which should not be iterated over. Negative entries count from
        the end, as in numpy (-1 is the last axis).

    Yields
    ------
    list
        The next list of indices and/or slice objects for each dimension.
        If `axes` is empty or None, a single ``[slice(None, None)]`` is
        yielded.

    Raises
    ------
    ValueError
        If shape is empty.
    ValueError
        If axes(axis) does not match shape.
    """
    if not shape:
        raise ValueError("shape cannot be None.")

    if axes:
        ndim = len(shape)
        # Previously only `axis < ndim` was checked, so negative axes slipped
        # through and were then silently ignored by the membership tests.
        if not all(-ndim <= axis < ndim for axis in axes):
            raise ValueError("axes(axis) does not match shape.")
        axes = [axis + ndim if axis < 0 else axis for axis in axes]

        # 0 marks a selected axis (kept whole), 1 an axis that is iterated.
        axes_select = [0 if x in axes else 1 for x in range(ndim)]
        axes_iterables = \
            [range(y) for x, y in enumerate(shape) if x not in axes]
        for index in product(*axes_iterables):
            it_iter = iter(index)
            slice_list = [
                next(it_iter) if axis else slice(None, None)
                for axis in axes_select
            ]
            yield slice_list
    else:
        yield [slice(None, None)]


def hermitianize_gaussian(x, axes=None):
    """
    Symmetrize `x` under point inversion plus complex conjugation, rescaled
    by sqrt(0.5) so that the variance of the entries is preserved.

    Parameters
    ----------
    x : array-like
        Input data; judging by the function name it is expected to hold
        Gaussian random numbers (NOTE(review): confirm with callers). It must
        support `.copy()`, `.conjugate()`, elementwise `==` with `.all()`,
        arithmetic and item assignment.
    axes : iterable of int, optional
        Axes along which the point inversion is performed. All axes if None.

    Returns
    -------
    The unchanged `x` if it already equals its inverted conjugate. Otherwise
    a new object ``(x + conj(inv(x))) * sqrt(0.5)`` whose fixed points of
    the inversion have been multiplied by a further sqrt(0.5). If that
    object accepts a `hermitian` attribute, it is set to True.
    """
    # make the point inversions
    flipped_x = _hermitianize_inverter(x, axes=axes)
    flipped_x = flipped_x.conjugate()
    # check if x was already hermitian
    if (x == flipped_x).all():
        return x

    # Average x and flipped_x. Correct the variance by multiplying with
    # sqrt(0.5) instead of dividing by 2.
    x = (x + flipped_x) * np.sqrt(0.5)

    dimensions = len(x.shape)
    if axes is None:
        axes = range(dimensions)
    # _hermitianize_inverter accepts negative axes via plain list indexing;
    # normalize them so the membership test below recognizes them as well.
    inverted_axes = set(axis + dimensions if axis < 0 else axis
                        for axis in axes)

    # The fixed points of the point inversion must not be averaged. Hence
    # one must multiply them again with sqrt(0.5).
    # Along an inverted axis of length n the inversion maps k -> (n - k) % n,
    # so its fixed points are 0 and, for even n only, n // 2. Along an axis
    # that is not inverted every index is a fixed point, hence the full slice.
    candidates = []
    for axis, length in enumerate(x.shape):
        if axis in inverted_axes:
            if length % 2 == 0:
                candidates.append((0, length // 2))
            else:
                candidates.append((0,))
        else:
            candidates.append((slice(None),))

    for fixed_index in product(*candidates):
        x[fixed_index] *= np.sqrt(0.5)

    # best effort: plain ndarrays do not accept new attributes
    try:
        x.hermitian = True
    except AttributeError:
        pass

    return x


def hermitianize(x, axes=None):
    """
    Average `x` with the complex conjugate of its point inversion.

    Parameters
    ----------
    x : array-like
        Input data; must support `.copy()`, `.conjugate()`, `+` and `/=`.
    axes : iterable of int, optional
        Axes along which the point inversion is performed. All axes if None.

    Returns
    -------
    A new object equal to ``(x + conj(inv(x))) / 2``; `x` is not modified.
    """
    mirrored = _hermitianize_inverter(x, axes=axes).conjugate()
    # The sum is a fresh object, so dividing it in place is safe.
    averaged = x + mirrored
    averaged /= 2.
    return averaged


def _hermitianize_inverter(x, axes):
Ultima's avatar
Ultima committed
121
    # calculate the number of dimensions the input array has
Ultima's avatar
Ultima committed
122
    dimensions = len(x.shape)
Ultima's avatar
Ultima committed
123
    # prepare the slicing object which will be used for mirroring
Jait Dixit's avatar
Jait Dixit committed
124
    slice_primitive = [slice(None), ] * dimensions
Ultima's avatar
Ultima committed
125
    # copy the input data
Ultima's avatar
Ultima committed
126
    y = x.copy()
127
128
129
130
131
132

    if axes is None:
        axes = xrange(dimensions)

    # flip in the desired directions
    for i in axes:
Ultima's avatar
Ultima committed
133
        slice_picker = slice_primitive[:]
Ultima's avatar
Ultima committed
134
        slice_picker[i] = slice(1, None, None)
ultimanet's avatar
ultimanet committed
135
136
        slice_picker = tuple(slice_picker)

Ultima's avatar
Ultima committed
137
138
        slice_inverter = slice_primitive[:]
        slice_inverter[i] = slice(None, 0, -1)
ultimanet's avatar
ultimanet committed
139
        slice_inverter = tuple(slice_inverter)
Ultima's avatar
Ultima committed
140
141

        try:
142
143
            y.set_data(to_key=slice_picker, data=y,
                       from_key=slice_inverter)
Ultima's avatar
Ultima committed
144
145
146
        except(AttributeError):
            y[slice_picker] = y[slice_inverter]
    return y


def direct_vdot(x, y):
    """
    Compute the vdot of `x` and `y`, unwrapping field-like inputs first.

    Each argument is replaced by the result of its `get_val()` method if it
    has one. Then `x.vdot(y)` is tried, then `y.vdot(x)`, and finally
    `numpy.vdot(x, y)` is used as fallback.
    """
    def _unwrap(obj):
        # fields expose their data via get_val(); anything else is used as is
        try:
            return obj.get_val()
        except AttributeError:
            return obj

    left = _unwrap(x)
    right = _unwrap(y)

    # prefer a vdot method offered by either operand, x's first
    for first, second in ((left, right), (right, left)):
        try:
            return first.vdot(second)
        except AttributeError:
            continue

    return np.vdot(left, right)

def convert_nested_list_to_object_array(x):
    """
    Convert nested lists/tuples into a numpy object array.

    Only the list/tuple nesting levels become array dimensions. Objects found
    at the innermost level, e.g. ndarrays, are stored as individual elements.

    Parameters
    ----------
    x : object
        A (possibly nested) list or tuple, or any other object.

    Returns
    -------
    `x` itself if it is not a list/tuple at the top level. Otherwise an
    object-dtype ndarray whose shape is the list/tuple nesting shape of `x`.
    """
    # if x is a nested_list full of ndarrays all having the same size,
    # np.shape returns the shape of the ndarrays, too, i.e. too many
    # dimensions
    possible_shape = np.shape(x)

    # Keep only the leading dimensions that really are list/tuple nesting.
    dimension_counter = 0
    current_extract = x
    for _ in range(len(possible_shape)):
        if not isinstance(current_extract, (list, tuple)):
            break
        dimension_counter += 1
        if len(current_extract) == 0:
            # an empty container has no first element to descend into
            # (indexing it used to raise IndexError)
            break
        current_extract = current_extract[0]
    real_shape = possible_shape[:dimension_counter]

    # if the numpy array was not encapsulated at all, return x directly
    if real_shape == ():
        return x

    # Prepare the carrier-object (`object`: the `np.object` alias is removed
    # in current numpy)
    carrier = np.empty(real_shape, dtype=object)
    for ii in np.ndindex(*real_shape):
        try:
            carrier[ii] = x[ii]
        except TypeError:
            # plain lists/tuples reject tuple indices; walk down level by level
            extracted = x
            for j in ii:
                extracted = extracted[j]
            carrier[ii] = extracted
    return carrier


def field_map(ishape, function, *args):
    """
    Apply `function` element-wise over a grid of shape `ishape`.

    Parameters
    ----------
    ishape : tuple
        Shape of the grid. If it is ``()``, `function(*args)` is returned
        directly.
    function : callable
        Called once per grid point.
    *args : array-like
        Optional operands. At each grid index, every operand is indexed with
        that index clipped to its own shape. Operands of shape (3, 4, 1) can
        therefore be combined with operands of shape (3, 4, 3).

    Returns
    -------
    The bare result of `function(*args)` if `ishape == ()`. Without `args`,
    an object array of shape `ishape` filled with `function()`. Otherwise an
    array shaped like (and with the dtype of) `args[0]`.
    """
    if ishape == ():
        return function(*args)

    if args == ():
        # `object`: the `np.object` alias is removed in current numpy
        result = np.empty(ishape, dtype=object)
        for ii in np.ndindex(*result.shape):
            result[ii] = function()
        return result

    # define a helper function in order to clip the get-indices
    # to be suitable for the foreign arrays in args.
    # This allows you to do operations, like adding to fields
    # with ishape (3,4,3) and (3,4,1)
    def get_clipped(w, ind):
        w_shape = np.array(np.shape(w))
        get_tuple = tuple(np.clip(ind, 0, w_shape - 1))
        return w[get_tuple]

    result = np.empty_like(args[0])
    for ii in np.ndindex(*result.shape):
        result[ii] = function(*[get_clipped(z, ii) for z in args])
    return result


def cast_axis_to_tuple(axis, length):
    """
    Normalize an axis specification to a tuple of non-negative ints.

    Parameters
    ----------
    axis : None, int-like scalar, or iterable of int-likes
        Axis or axes to normalize. Negative values count from the end.
    length : int
        Number of available axes.

    Returns
    -------
    None if `axis` is None, otherwise a tuple of ints in ``[0, length)``.
    Order and duplicates are preserved.

    Raises
    ------
    TypeError
        If `axis` is neither iterable nor a scalar.
    AssertionError
        If an entry lies outside ``[0, length)`` after normalization.
    """
    if axis is None:
        return None
    try:
        axis = tuple(int(item) for item in axis)
    except TypeError:
        if np.isscalar(axis):
            axis = (int(axis),)
        else:
            raise TypeError(
                "Could not convert axis-input to tuple of ints")

    # shift negative indices to positive ones
    axis = tuple(item if item >= 0 else item + length for item in axis)

    # Duplicate entries are deliberately kept, in order to allow for the
    # ComposedOperator.

    # All entries must lie in [0, length). The check is raised explicitly
    # instead of via `assert` so it survives `python -O`; AssertionError is
    # kept for compatibility with existing callers.
    for elem in axis:
        if not 0 <= elem < length:
            raise AssertionError(
                "axis %d is out of range for length %d" % (elem, length))

    return axis


def complex_bincount(x, weights=None, minlength=None):
    """
    Bincount that also supports complex-valued weights.

    Parameters
    ----------
    x : object with a ``bincount(weights=..., minlength=...)`` method, or
        array-like of non-negative ints (then ``numpy.bincount`` is used)
        The bin indices.
    weights : array-like, optional
        Weights; if they have a complex dtype, the real and imaginary parts
        are binned separately and recombined.
    minlength : int, optional
        Minimum number of bins.

    Returns
    -------
    The (possibly complex) weighted bin counts.
    """
    def _bincount(w):
        try:
            bincount = x.bincount
        except AttributeError:
            # plain sequences/ndarrays: numpy rejects minlength=None
            return np.bincount(x, weights=w,
                               minlength=0 if minlength is None else minlength)
        return bincount(weights=w, minlength=minlength)

    try:
        complex_weights_Q = issubclass(weights.dtype.type,
                                       np.complexfloating)
    except AttributeError:
        complex_weights_Q = False

    if complex_weights_Q:
        real_bincount = _bincount(weights.real)
        imag_bincount = _bincount(weights.imag)
        # the imaginary part has to be re-attached as imaginary; adding it
        # as a real number silently corrupted the result
        return real_bincount + 1j * imag_bincount
    return _bincount(weights)


def get_default_codomain(domain):
    """
    Return the default harmonic codomain of `domain`.

    The codomain is obtained from the transformation class associated with
    the domain's space type. Raises TypeError for unsupported domains.
    """
    from nifty.spaces import RGSpace, HPSpace, GLSpace, LMSpace
    from nifty.operators.fft_operator.transformations import \
        RGRGTransformation, HPLMTransformation, GLLMTransformation, \
        LMGLTransformation

    # Checked in order; the first matching space type wins.
    # TODO: get the preferred transformation path for LMSpace from config
    transformation_table = (
        (RGSpace, RGRGTransformation),
        (HPSpace, HPLMTransformation),
        (GLSpace, GLLMTransformation),
        (LMSpace, LMGLTransformation),
    )
    for space_type, transformation in transformation_table:
        if isinstance(domain, space_type):
            return transformation.get_codomain(domain)

    raise TypeError('ERROR: unknown domain')


def parse_domain(domain):
    """
    Normalize `domain` to a tuple of DomainObject instances.

    None becomes the empty tuple and a single DomainObject is wrapped in a
    1-tuple. Any other non-tuple input is converted with `tuple()`. Raises
    TypeError if any resulting entry is not a DomainObject.
    """
    from nifty.domain_object import DomainObject

    if domain is None:
        parsed = ()
    elif isinstance(domain, DomainObject):
        parsed = (domain,)
    elif isinstance(domain, tuple):
        parsed = domain
    else:
        parsed = tuple(domain)

    if not all(isinstance(entry, DomainObject) for entry in parsed):
        raise TypeError(
            "Given object contains something that is not a "
            "instance of DomainObject-class.")
    return parsed