utilities.py 13.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
Theo Steininger's avatar
Theo Steininger committed
13
#
14
# Copyright(C) 2013-2019 Max-Planck-Society
Theo Steininger's avatar
Theo Steininger committed
15
#
16
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
Theo Steininger's avatar
Theo Steininger committed
17

18
import collections
import collections.abc
from functools import reduce, wraps
from itertools import product

import numpy as np

Martin Reinecke's avatar
Martin Reinecke committed
24
# Names exported by ``from <this module> import *``.
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
           "memo", "NiftyMeta", "my_sum", "my_lincomb_simple",
           "my_lincomb", "indent",
           "my_product", "frozendict", "special_add_at", "iscomplextype",
           "value_reshaper", "lognormal_moments",
           "check_MPI_equality", "check_MPI_synced_random_state"]
30
31


Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
32
33
def my_sum(iterable):
    """Add up all items of `iterable` from left to right using ``+``.

    No start value is supplied, so the items only need to support ``+`` among
    themselves, and an empty iterable makes `reduce` raise a TypeError.
    """
    def _plus(accumulated, item):
        return accumulated+item

    return reduce(_plus, iterable)
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48


def my_lincomb_simple(terms, factors):
    """Return the sum of ``term*factor`` over the pairs of `terms` and
    `factors`.

    The two iterables are paired with `zip`, so surplus entries of the longer
    one are ignored. The products are summed with `my_sum`.
    """
    scaled = (term*factor for term, factor in zip(terms, factors))
    return my_sum(scaled)


def my_lincomb(terms, factors):
    """Return the sum of ``term*factor`` over the pairs of `terms` and
    `factors`, skipping the multiplication when the factor equals 1.

    The two iterables are paired with `zip`, so surplus entries of the longer
    one are ignored. The (possibly scaled) terms are summed with `my_sum`.
    """
    scaled = (term if factor == 1. else term*factor
              for term, factor in zip(terms, factors))
    return my_sum(scaled)


def my_product(iterable):
    """Multiply all items of `iterable` from left to right using ``*``.

    No start value is supplied, so an empty iterable makes `reduce` raise a
    TypeError.
    """
    def _times(accumulated, item):
        return accumulated*item

    return reduce(_times, iterable)
Martin Reinecke's avatar
Martin Reinecke committed
49

50

51
52
def get_slice_list(shape, axes):
    """
    Helper function which generates index tuples to traverse over all
    combinations of axes, other than the selected axes.

    Parameters
    ----------
    shape: tuple
        Shape of the data array to traverse over.
    axes: tuple
        Axes which should not be iterated over. Negative entries count from
        the last axis. If `axes` is None or empty, a single full slice is
        yielded.

    Yields
    ------
    tuple
        The next tuple of indices and/or slice objects for each dimension.
        The selected axes are represented by ``slice(None, None)``, all other
        axes by an integer index.

    Raises
    ------
    ValueError
        If shape is None.
        If axes(axis) does not match shape.

    Notes
    -----
    This is a generator, so the exceptions are raised when the first item is
    requested, not when the function is called.
    """
    if shape is None:
        raise ValueError("shape cannot be None.")

    if not axes:
        # A tuple (not a list) so that the result can be used directly as a
        # NumPy index, consistent with the branch below.
        yield (slice(None, None),)
        return

    ndim = len(shape)
    if not all(-ndim <= axis < ndim for axis in axes):
        raise ValueError("axes(axis) does not match shape.")
    # Map negative axes onto their non-negative equivalent; otherwise they
    # would pass validation but never match an axis number below.
    kept = {axis % ndim for axis in axes}

    free_ranges = [range(length) for axis, length in enumerate(shape)
                   if axis not in kept]
    for index in product(*free_ranges):
        it_iter = iter(index)
        yield tuple(slice(None, None) if axis in kept else next(it_iter)
                    for axis in range(ndim))
Theo Steininger's avatar
Theo Steininger committed
92

Theo Steininger's avatar
Theo Steininger committed
93

94
95
96
97
98
99
100
def safe_cast(tfunc, val):
    """Convert `val` with `tfunc` and make sure the value did not change.

    Parameters
    ----------
    tfunc : callable
        Conversion function, e.g. `int`.
    val : object
        The value to convert.

    Returns
    -------
    The converted value ``tfunc(val)``.

    Raises
    ------
    ValueError
        If the converted value does not compare equal to `val`.
    """
    tmp = tfunc(val)
    if val != tmp:
        raise ValueError("value changed during cast")
    return tmp


def parse_spaces(spaces, nspc):
    """Normalize a specification of space indices to a tuple of ints.

    Parameters
    ----------
    spaces : None, scalar or iterable
        None selects all `nspc` spaces. A scalar (according to `np.isscalar`)
        selects a single space. Anything else is iterated over. All entries
        are converted with ``safe_cast(int, ...)``.
    nspc : int
        Total number of spaces; valid indices are ``0 <= index < nspc``.

    Returns
    -------
    tuple of int
        The selected space indices, in the order they were given.

    Raises
    ------
    ValueError
        If a value changes when cast to int, if an index lies outside
        ``[0, nspc)``, or if an index occurs more than once.
    """
    nspc = safe_cast(int, nspc)
    if spaces is None:
        return tuple(range(nspc))
    if np.isscalar(spaces):
        spaces = (safe_cast(int, spaces),)
    else:
        spaces = tuple(safe_cast(int, item) for item in spaces)
    if len(spaces) == 0:
        return spaces
    # Use min/max explicitly: the iteration order of a set is arbitrary, so
    # its first/last element cannot be used as smallest/largest index.
    if min(spaces) < 0 or max(spaces) >= nspc:
        raise ValueError("space index out of range")
    if len(set(spaces)) != len(spaces):
        raise ValueError("multiply defined space indices")
    return spaces
Martin Reinecke's avatar
Martin Reinecke committed
117
118


119
120
121
def infer_space(domain, space):
    """Validate a space index for `domain`, defaulting to 0 when unambiguous.

    Parameters
    ----------
    domain : sized object
        Only ``len(domain)`` is used here. Presumably a DomainTuple, judging
        from the error message below.
    space : int or None
        Index of the requested space. None is only allowed if `domain` has
        exactly one entry, in which case index 0 is used.

    Returns
    -------
    int
        The space index, converted with `int`.

    Raises
    ------
    ValueError
        If `space` is None although `domain` does not have exactly one entry,
        or if the index lies outside ``[0, len(domain))``.
    """
    if space is None:
        if len(domain) != 1:
            raise ValueError("'space' index must be given for objects based on"
                             " DomainTuples containing more than one domain")
        space = 0
    space = int(space)
    if space < 0 or space >= len(domain):
        raise ValueError("space index out of range")
    return space


Martin Reinecke's avatar
Martin Reinecke committed
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def memo(f):
    """Decorator caching the result of an argument-less method per instance.

    The result of the first call is stored in the dictionary ``self._cache``
    (created on demand) under the name of the decorated function; later calls
    on the same instance return the stored value without calling `f` again.

    Parameters
    ----------
    f : callable
        Method taking only `self`.

    Returns
    -------
    callable
        The caching wrapper. It carries the name and docstring of `f`.
    """
    name = f.__name__

    # wraps() keeps __name__/__doc__ of the original method visible on the
    # wrapper, e.g. for help() and for docstring inheritance.
    @wraps(f)
    def wrapped_f(self):
        if not hasattr(self, "_cache"):
            self._cache = {}
        try:
            return self._cache[name]
        except KeyError:
            self._cache[name] = f(self)
            return self._cache[name]
    return wrapped_f


class _DocStringInheritor(type):
    """
    A variation on
Martin Reinecke's avatar
Martin Reinecke committed
148
    https://groups.google.com/group/comp.lang.python/msg/26f7b4fcb4d66c95
Martin Reinecke's avatar
Martin Reinecke committed
149
150
151
152
153
154
155
156
157
158
    by Paul McGuire
    """
    def __new__(meta, name, bases, clsdict):
        if not('__doc__' in clsdict and clsdict['__doc__']):
            for mro_cls in (mro_cls for base in bases
                            for mro_cls in base.mro()):
                doc = mro_cls.__doc__
                if doc:
                    clsdict['__doc__'] = doc
                    break
Martin Reinecke's avatar
Martin Reinecke committed
159
        for attr, attribute in clsdict.items():
Martin Reinecke's avatar
Martin Reinecke committed
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
            if not attribute.__doc__:
                for mro_cls in (mro_cls for base in bases
                                for mro_cls in base.mro()
                                if hasattr(mro_cls, attr)):
                    doc = getattr(getattr(mro_cls, attr), '__doc__')
                    if doc:
                        if isinstance(attribute, property):
                            clsdict[attr] = property(attribute.fget,
                                                     attribute.fset,
                                                     attribute.fdel,
                                                     doc)
                        else:
                            attribute.__doc__ = doc
                        break
        return super(_DocStringInheritor, meta).__new__(meta, name,
                                                        bases, clsdict)


Martin Reinecke's avatar
Martin Reinecke committed
178
class NiftyMeta(_DocStringInheritor):
    """Metaclass exported by this module.

    It adds nothing of its own on top of `_DocStringInheritor`.
    """
Martin Reinecke's avatar
Martin Reinecke committed
180
181


Martin Reinecke's avatar
Martin Reinecke committed
182
class frozendict(collections.abc.Mapping):
    """
    An immutable wrapper around dictionaries that implements the complete
    :py:class:`collections.abc.Mapping` interface. It can be used as a drop-in
    replacement for dictionaries where immutability is desired.

    The constructor accepts the same arguments as `dict`. Instances are
    hashable as long as all keys and values are hashable.
    """

    # Type of the wrapped dictionary; a subclass may override it.
    dict_cls = dict

    def __init__(self, *args, **kwargs):
        self._dict = self.dict_cls(*args, **kwargs)
        # Computed lazily on the first call to __hash__.
        self._hash = None

    def __getitem__(self, key):
        return self._dict[key]

    def __contains__(self, key):
        return key in self._dict

    def copy(self, **add_or_replace):
        """Return a new instance of the same class with the entries given as
        keyword arguments added or replaced."""
        return self.__class__(self, **add_or_replace)

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        return len(self._dict)

    def __repr__(self):
        return '<{} {}>'.format(self.__class__.__name__, self._dict)

    def __hash__(self):
        if self._hash is None:
            # XOR makes the result independent of the iteration order.
            h = 0
            for key, value in self._dict.items():
                h ^= hash((key, value))
            self._hash = h
        return self._hash
Martin Reinecke's avatar
Martin Reinecke committed
220
221


Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
222
223
224
def special_add_at(a, axis, index, b):
    """Add the entries of `b` onto `a` along `axis` at positions `index`.

    For every combination of the remaining indices, entry ``j`` of `b` along
    `axis` is accumulated onto entry ``index[j]`` of `a` along `axis`.
    Repeated indices are summed up, since the accumulation is done with
    `np.bincount`.

    Parameters
    ----------
    a : numpy.ndarray
        Array to add onto.
    axis : int
        Axis of `a` and `b` along which `index` is applied.
        NOTE(review): only non-negative values appear to be handled as
        intended by the slicing ``a.shape[:axis]`` / ``a.shape[axis+1:]``.
    index : array of non-negative ints
        Target position along `axis` for every entry of `b` along `axis`.
        Passed to `np.bincount` unchanged.
    b : numpy.ndarray
        Values to add; must have the same dtype as `a`.

    Returns
    -------
    numpy.ndarray
        The result with the shape of `a`.
        NOTE(review): whether `a` itself is modified depends on whether
        ``a.reshape`` returns a view or a copy; use the return value.

    Raises
    ------
    TypeError
        If the dtypes of `a` and `b` differ.
    """
    if a.dtype != b.dtype:
        raise TypeError("data type mismatch")
    # Collapse all axes before/after `axis` so that a simple double loop
    # covers every 1D line along `axis`.
    sz1 = int(np.prod(a.shape[:axis]))
    sz3 = int(np.prod(a.shape[axis+1:]))
    a2 = a.reshape([sz1, -1, sz3])
    b2 = b.reshape([sz1, -1, sz3])
    if iscomplextype(a.dtype):
        # Reinterpret complex values as pairs of real values, which doubles
        # the length of the last axis.
        dt2 = a.real.dtype
        a2 = a2.view(dt2)
        b2 = b2.view(dt2)
        sz3 *= 2
    for i1 in range(sz1):
        for i3 in range(sz3):
            a2[i1, :, i3] += np.bincount(index, b2[i1, :, i3],
                                         minlength=a2.shape[1])

    if iscomplextype(a.dtype):
        a2 = a2.view(a.dtype)
    return a2.reshape(a.shape)
Martin Reinecke's avatar
Martin Reinecke committed
242
243
244


# Scalar types which are treated as complex by `iscomplextype`.
_iscomplex_tpl = (np.complex64, np.complex128)


def iscomplextype(dtype):
    """Return True if `dtype` is one of the complex types in `_iscomplex_tpl`.

    Parameters
    ----------
    dtype : numpy.dtype or anything accepted by the `numpy.dtype` constructor
        The data type to check, e.g. ``arr.dtype``, ``np.complex128`` or
        ``float``.

    Returns
    -------
    bool
    """
    # np.dtype() returns dtype objects unchanged and additionally allows
    # scalar types and type names to be passed in.
    return np.dtype(dtype).type in _iscomplex_tpl
249
250
251
252


def indent(inp):
    """Prefix every line of the string `inp` with two spaces.

    Trailing whitespace is stripped from each resulting line, so blank lines
    stay empty. The lines are re-joined with a newline character and the
    result has no trailing newline.
    """
    shifted = ["  " + line for line in inp.splitlines()]
    return "\n".join(line.rstrip() for line in shifted)
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281


def shareRange(nwork, nshares, myshare):
    """Split `nwork` work items into `nshares` contiguous index ranges whose
    lengths differ by at most one, and return the range of one share.

    The first ``nwork % nshares`` shares receive one item more than the
    remaining ones.

    Parameters
    ----------
    nwork: int
        number of work items
    nshares: int
        number of shares among which the work should be distributed
    myshare: int
        the share for which the range of work items is requested

    Returns
    -------
    lo, hi: int
        index range of work items for this share; ``hi - lo`` is the number
        of items in the share
    """
    per_share, leftover = divmod(nwork, nshares)
    extra = 1 if myshare < leftover else 0
    # Every share before this one holds per_share items, plus one more for
    # each of them that lies among the first `leftover` shares.
    start = myshare*per_share + min(myshare, leftover)
    return start, start + per_share + extra


Martin Reinecke's avatar
Martin Reinecke committed
282
283
284
285
286
287
288
289
290
291

def get_MPI_params_from_comm(comm):
    """Return ``(size, rank, master)`` for the given communicator.

    Parameters
    ----------
    comm : MPI communicator or None
        Object providing `Get_size()` and `Get_rank()`. None stands for a run
        without MPI and yields ``(1, 0, True)``.

    Returns
    -------
    size : int
        value of ``comm.Get_size()``
    rank : int
        value of ``comm.Get_rank()``
    master : bool
        True if rank == 0, else False
    """
    if comm is not None:
        ntask = comm.Get_size()
        my_rank = comm.Get_rank()
        return ntask, my_rank, my_rank == 0
    return 1, 0, True



292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
def get_MPI_params():
    """Returns basic information about the MPI setup of the running script.

    If `mpi4py` cannot be imported, or if the world communicator contains
    only a single task, the values for a run without MPI are returned.

    Returns
    -------
    comm: MPI communicator or None
        if MPI is detected _and_ more than one task is active, returns
        the world communicator, else returns None
    size: int
        the number of tasks running in total
    rank: int
        the rank of this task
    master: bool
        True if rank == 0, else False
    """
    try:
        from mpi4py import MPI
        world = MPI.COMM_WORLD
        ntask = world.Get_size()
        if ntask != 1:
            my_rank = world.Get_rank()
            return world, ntask, my_rank, my_rank == 0
    except ImportError:
        pass
    # Either mpi4py is not available or there is only one task.
    return None, 1, 0, True


def allreduce_sum(obj, comm):
    """ This is a deterministic implementation of MPI allreduce

    Numeric addition is not associative due to rounding errors.
    Therefore we provide our own implementation that is consistent
    no matter if MPI is used and how many tasks there are.

    At the beginning, a list `who` is constructed, that states which obj can
    be found on which MPI task.
    Then elements are added pairwise, with increasing pair distance.
    In the first round, the distance between pair members is 1:
      v[0] := v[0] + v[1]
      v[2] := v[2] + v[3]
      v[4] := v[4] + v[5]
    Entries whose summation partner lies beyond the end of the array
    stay unchanged.
    When both summation partners are not located on the same MPI task,
    the second summand is sent to the task holding the first summand and
    the operation is carried out there.
    For the next round, the distance is doubled:
      v[0] := v[0] + v[2]
      v[4] := v[4] + v[6]
      v[8] := v[8] + v[10]
    This is repeated until the distance exceeds the length of the array.
    At this point v[0] contains the sum of all entries, which is then
    broadcast to all tasks.

    Parameters
    ----------
    obj : iterable
        The summands held by this task. They are combined with ``+``.
    comm : MPI communicator or None
        If None, all summands are taken from `obj` and no communication
        takes place. Otherwise the summands of all tasks of `comm` are
        combined, ordered by task rank.

    Returns
    -------
    The sum of all summands; with a communicator, every task receives it via
    a broadcast.

    Notes
    -----
    If there is no summand at all (on any task), the accesses ``vals[0]`` /
    ``who[0]`` at the end raise an IndexError.
    """
    vals = list(obj)
    if comm is None:
        # Serial case: every summand lives on (virtual) task 0.
        nobj = len(vals)
        who = np.zeros(nobj, dtype=np.int32)
        rank = 0
    else:
        ntask = comm.Get_size()  # not used further in this function
        rank = comm.Get_rank()
        # Number of summands on every task -> global index range [lo, hi)
        # of the summands held by each task.
        nobj_list = comm.allgather(len(vals))
        all_hi = list(np.cumsum(nobj_list))
        all_lo = [0] + all_hi[:-1]
        nobj = all_hi[-1]
        rank_lo_hi = [(l, h) for l, h in zip(all_lo, all_hi)]
        lo, hi = rank_lo_hi[rank]
        # Pad the local list with None so that it can be addressed with
        # global indices; only the slots lo..hi-1 hold actual data.
        vals = [None]*lo + vals + [None]*(nobj-hi)
        # who[j] is the task on which summand j is located.
        who = [t for t, (l, h) in enumerate(rank_lo_hi) for cnt in range(h-l)]

    step = 1
    while step < nobj:
        for j in range(0, nobj, 2*step):
            if j+step < nobj:  # summation partner found
                if rank == who[j]:
                    if who[j] == who[j+step]:  # no communication required
                        vals[j] = vals[j] + vals[j+step]
                        vals[j+step] = None
                    else:
                        # Matches the send() issued by the task holding the
                        # second summand in the branch below.
                        vals[j] = vals[j] + comm.recv(source=who[j+step])
                elif rank == who[j+step]:
                    comm.send(vals[j+step], dest=who[j])
                    # The summand has been handed over; drop the local
                    # reference.
                    vals[j+step] = None
        step *= 2
    if comm is None:
        return vals[0]
    # The total is located on the task holding summand 0.
    return comm.bcast(vals[0], root=who[0])
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398


def value_reshaper(x, N):
    """Produce arrays of shape `(N,)`.
    If `x` is a scalar or array of length one, fill the target array with it.
    If `x` is an array, check if it has the right shape.

    Parameters
    ----------
    x : scalar or array-like
        Input value(s); converted to a float64 array.
    N : int
        Requested length. For ``N == 0`` a scalar or length-one input is
        returned as a zero-dimensional array.

    Returns
    -------
    numpy.ndarray
        Array of dtype float64 with shape `(N,)` (or `()`, see above).

    Raises
    ------
    TypeError
        If the shape of `x` is neither `()`, `(1,)` nor `(N,)`.
    """
    # np.asfarray has been removed in NumPy 2.0; with its default dtype it
    # was equivalent to this conversion.
    x = np.asarray(x, dtype=np.float64)
    if x.shape in [(), (1, )]:
        return np.full(N, x) if N != 0 else x.reshape(())
    elif x.shape == (N, ):
        return x
    raise TypeError("x and N are incompatible")


def lognormal_moments(mean, sigma, N=0):
    """Calculates the parameters for a normal distribution `n(x)`
    such that `exp(n)(x)` has the mean and standard deviation given.

    Used in :func:`~nifty7.normal_operators.LognormalTransform`.

    Parameters
    ----------
    mean : scalar or array-like
        Desired mean of the log-normal distribution; must be positive.
    sigma : scalar or array-like
        Desired standard deviation of the log-normal distribution; must be
        positive.
    N : int
        Length passed to `value_reshaper` for both `mean` and `sigma`.
        Default: 0.

    Returns
    -------
    logmean, logsigma : numpy.ndarray
        Mean and standard deviation of the underlying normal distribution.

    Raises
    ------
    ValueError
        If not all entries of `mean` or of `sigma` are greater than zero.
    """
    mean, sigma = (value_reshaper(param, N) for param in (mean, sigma))
    if not np.all(mean > 0):
        raise ValueError("mean must be greater 0; got {!r}".format(mean))
    if not np.all(sigma > 0):
        raise ValueError("sigma must be greater 0; got {!r}".format(sigma))

    logsigma = np.sqrt(np.log1p((sigma / mean)**2))
    logmean = np.log(mean) - logsigma**2 / 2
    return logmean, logsigma
409
410
411
412
413
414
415


def myassert(val):
    """Raise an AssertionError (without message) if `val` is falsy.

    Unlike the `assert` statement, this check does not depend on
    `__debug__`.
    """
    if val:
        return
    raise AssertionError
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432


def check_MPI_equality(obj, comm):
    """Check that object is the same on all MPI tasks associated to a given
    communicator.

    Raises a RuntimeError if it differs.

    Parameters
    ----------
    obj :
        Any Python object that implements __eq__.
    comm : MPI communicator or None
        If comm is None, no check will be performed

    Raises
    ------
    RuntimeError
        If `_MPI_unique` reports that the object differs between the tasks.
    """
    if comm is None:
        return
    if not _MPI_unique(obj, comm):
        raise RuntimeError("MPI tasks are not in sync")
435
436
437


def _MPI_unique(obj, comm):
Philipp Arras's avatar
Philipp Arras committed
438
439
440
441
442
443
    from collections.abc import Hashable
    import pickle
    objects = comm.allgather(obj)
    if not isinstance(objects[0], Hashable) or isinstance(objects[0], np.random.SeedSequence):
        objects = [pickle.dumps(oo) for oo in objects]
    return len(set(objects)) == 1
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460


def check_MPI_synced_random_state(comm):
    """Verify that all MPI tasks of `comm` hold the same random state.

    The state returned by `getState` of the sibling module `random` is
    compared across the tasks with `check_MPI_equality`, which raises a
    RuntimeError if it differs.

    Parameters
    ----------
    comm : MPI communicator or None
        If comm is None, no check will be performed
    """
    from .random import getState
    if comm is not None:
        check_MPI_equality(getState(), comm)