# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

"""
Some remarks on NIFTy's treatment of random numbers

NIFTy makes use of the `Generator` and `SeedSequence` classes introduced to
`numpy.random` in numpy 1.17.

When the `nifty6.random` module is first loaded, it creates a stack of
`SeedSequence` objects containing a single `SeedSequence` with a fixed seed
(42), and a stack of `Generator` objects containing a single generator derived
from that seed sequence. Without user intervention, this generator is used for
all random number generation tasks within NIFTy. This means

- that random numbers drawn by NIFTy will be reproducible across multiple runs
  (assuming there are no complications like MPI-enabled runs with a varying
  number of tasks), and

- that trying to change random seeds via `numpy.random.seed` will have no
  effect on the random numbers drawn by NIFTy.
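
Both properties can be checked with plain numpy. The following is a minimal
sketch, not NIFTy code: the hypothetical helper `make_default_stacks` recreates
the module's initial stacks (one `SeedSequence(42)` and one generator derived
from it), and reseeding numpy's legacy global RNG in between has no influence
on the draws.

```python
import numpy as np


def make_default_stacks():
    # Hypothetical helper mirroring the module's initial state:
    # one SeedSequence with the fixed seed 42, one Generator derived from it.
    sseq = [np.random.SeedSequence(42)]
    rng = [np.random.default_rng(sseq[-1])]
    return sseq, rng


_, rng_run1 = make_default_stacks()
draws_run1 = rng_run1[-1].standard_normal(5)

# numpy.random.seed only reseeds numpy's legacy global RandomState ...
np.random.seed(12345)

# ... so a "second run" starting from the fixed default still reproduces
# exactly the same numbers.
_, rng_run2 = make_default_stacks()
draws_run2 = rng_run2[-1].standard_normal(5)

assert np.array_equal(draws_run1, draws_run2)
```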

Users who want to change the random seed for a given run can achieve this
by calling :func:`push_sseq_from_seed()` with a seed of their choice. This will
push a new seed sequence generated from that seed onto the seed sequence stack,
and a generator derived from this seed sequence onto the generator stack.
Since all NIFTy RNG-related calls will use the generator on the top of the stack,
all calls from this point on will use the new generator.
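
A minimal sketch of this push operation on module-style stacks, using only
numpy; `push_from_seed` is a hypothetical stand-in written for illustration,
not the NIFTy function itself. Any code that always draws from `rng[-1]` picks
up the newly pushed generator automatically.

```python
import numpy as np

# Module-style stacks starting from the fixed default seed.
sseq = [np.random.SeedSequence(42)]
rng = [np.random.default_rng(sseq[-1])]


def push_from_seed(seed):
    # Hypothetical stand-in: push a SeedSequence built from `seed`
    # and a Generator derived from it.
    sseq.append(np.random.SeedSequence(seed))
    rng.append(np.random.default_rng(sseq[-1]))


push_from_seed(1234)
drawn = rng[-1].integers(0, 100, size=4)  # drawn from the new top generator

# A fresh generator built from seed 1234 yields the same numbers.
expected = np.random.default_rng(np.random.SeedSequence(1234)).integers(0, 100, size=4)
assert np.array_equal(drawn, expected)
```
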
If the user already has a `SeedSequence` object at hand, they can pass this to
NIFTy via :func:`push_sseq`. A new generator derived from this sequence will then
also be pushed onto the generator stack.
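
One situation where a ready-made `SeedSequence` is handy is a run seeded from
fresh OS entropy that should still be replayable later. In this numpy-only
sketch, `push` is a hypothetical stand-in for the push operation; the
sequence's `entropy` attribute is what would have to be recorded.

```python
import numpy as np

sseq = [np.random.SeedSequence(42)]
rng = [np.random.default_rng(sseq[-1])]


def push(seq):
    # Hypothetical stand-in: push an existing SeedSequence and a Generator
    # derived from it.
    sseq.append(seq)
    rng.append(np.random.default_rng(seq))


# A SeedSequence seeded from fresh OS entropy; its `entropy` attribute
# records what is needed to recreate it later.
fresh = np.random.SeedSequence()
push(fresh)
first = rng[-1].random(3)

# Rebuilding the sequence from the recorded entropy reproduces the draws.
replay = np.random.default_rng(np.random.SeedSequence(fresh.entropy)).random(3)
assert np.array_equal(first, replay)
```
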
These operations can be reverted (and should be, as soon as the new generator is
no longer needed) by a call to :func:`pop_sseq()`.
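
The module's `Context` class automates this push/pop pairing. As an
illustration only, the sketch below obtains the same guarantee with
`contextlib.contextmanager` and `try`/`finally`, a different technique from the
class-based `Context`; `scoped_seed` is hypothetical. It also shows that after
the pop, the previous generator resumes exactly where it stopped.

```python
import contextlib

import numpy as np

sseq = [np.random.SeedSequence(42)]
rng = [np.random.default_rng(sseq[-1])]


@contextlib.contextmanager
def scoped_seed(seed):
    # Push on entry and always pop on exit, even if the body raises.
    sseq.append(np.random.SeedSequence(seed))
    rng.append(np.random.default_rng(sseq[-1]))
    try:
        yield rng[-1]
    finally:
        sseq.pop()
        rng.pop()


before = rng[-1].random(2)
with scoped_seed(7) as tmp:
    temporary = tmp.random(2)  # drawn from the temporary generator
after = rng[-1].random(2)      # the default generator resumes its stream

# A reference generator with the default seed, never interrupted,
# yields `before` followed by `after`.
reference = np.random.default_rng(np.random.SeedSequence(42)).random(4)
assert np.array_equal(np.concatenate([before, after]), reference)
```
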
When users need direct access to the RNG currently in use, they can access it
via the :func:`current_rng` function.
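
`current_rng` returns the generator object itself, not a copy, so direct draws
advance the same stream that NIFTy uses afterwards. A minimal numpy-only
sketch, where `top_rng` is a hypothetical stand-in for `current_rng`:

```python
import numpy as np

# Module-style stacks starting from the fixed default seed.
sseq = [np.random.SeedSequence(42)]
rng = [np.random.default_rng(sseq[-1])]


def top_rng():
    # Hypothetical stand-in: return the generator on top of the stack.
    return rng[-1]


gen = top_rng()
assert isinstance(gen, np.random.Generator)

gen.random()          # a direct draw advances the shared generator ...
x = rng[-1].random()  # ... so the next draw is the stream's second number

ref = np.random.default_rng(np.random.SeedSequence(42)).random(2)
assert x == ref[1]
```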


Example for using multiple seed sequences:

Assume that N samples, distributed over a variable number of MPI tasks, are
needed to compute a KL divergence. In this situation, whenever random numbers
need to be drawn for these samples:

- each MPI task should spawn as many seed sequences as there are samples
  *in total*, using ``sseq = spawn_sseq(N)``

- each task loops over the local samples

  - first pushing the seed sequence for the **global** index of the
    current sample via ``push_sseq(sseq[iglob])``

  - drawing the required random numbers

  - then popping the seed sequence again via ``pop_sseq()``

That way, random numbers should be reproducible and independent of the number
of MPI tasks.
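
The recipe can be simulated in a single process with plain numpy. In this
sketch the loop over `rank` stands in for separate MPI tasks (each of which
would hold its own copy of the default `SeedSequence(42)`), `local_indices` is a
hypothetical block distribution, and building a generator from
`children[iglob]` plays the role of `push_sseq`. The assembled samples come out
identical for every task count.

```python
import numpy as np

N = 7  # total number of samples


def local_indices(rank, ntask, n):
    # Hypothetical block distribution of global sample indices over tasks;
    # any deterministic split would do.
    lo = rank * n // ntask
    hi = (rank + 1) * n // ntask
    return range(lo, hi)


def run(ntask):
    # Simulate `ntask` MPI tasks in one process and collect all samples.
    samples = {}
    for rank in range(ntask):
        # Each task starts from the same default state and spawns N children,
        # i.e. one per sample in total, not one per local sample.
        parent = np.random.SeedSequence(42)
        children = parent.spawn(N)
        for iglob in local_indices(rank, ntask, N):
            gen = np.random.default_rng(children[iglob])  # role of push_sseq
            samples[iglob] = gen.standard_normal(3)      # draw for this sample
            # role of pop_sseq: the temporary generator is simply discarded
    return np.array([samples[i] for i in range(N)])


reference = run(1)
for ntask in (2, 3, 5):
    assert np.array_equal(run(ntask), reference)
```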

WARNING: do not push/pop the same `SeedSequence` object more than once - this
will lead to repeated random sequences! Whenever you have to push `SeedSequence`
objects, generate new ones via :func:`spawn_sseq()`.
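
The failure mode described in this warning is easy to reproduce with plain
numpy. This sketch calls `SeedSequence.spawn` directly, which is the call that
`spawn_sseq` forwards to.

```python
import numpy as np

parent = np.random.SeedSequence(42)
(child,) = parent.spawn(1)

# Pushing the same SeedSequence object twice builds two generators
# in the same initial state: the "random" numbers repeat.
first = np.random.default_rng(child).random(3)
second = np.random.default_rng(child).random(3)
assert np.array_equal(first, second)

# Spawning again from the same parent yields a new, distinct child,
# because the parent keeps count of the children it has handed out.
(next_child,) = parent.spawn(1)
third = np.random.default_rng(next_child).random(3)
assert not np.array_equal(first, third)
```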
"""

import numpy as np

# Stack of SeedSequence objects. It always starts out with a well-defined
# default. Users can change the "random seed" used by a calculation by pushing
# a different SeedSequence before invoking any other nifty6.random calls.
_sseq = [np.random.SeedSequence(42)]
# Stack of random number generators associated with _sseq.
_rng = [np.random.default_rng(_sseq[-1])]


def getState():
    """Returns the full internal state of the module. Intended for pickling.

    Returns
    -------
    state : unspecified
    """
    import pickle
    return pickle.dumps((_sseq, _rng))


def setState(state):
    """Restores the full internal state of the module. Intended for unpickling.

    Parameters
    ----------
    state : unspecified
        Result of an earlier call to `getState`.
    """
    import pickle
    global _sseq, _rng
    _sseq, _rng = pickle.loads(state)


def spawn_sseq(n, parent=None):
    """Returns a list of `n` SeedSequence objects which are children of `parent`.

    Parameters
    ----------
    n : int
        number of requested SeedSequence objects
    parent : SeedSequence or None
        the object from which the returned objects will be derived.
        If `None`, the top of the current SeedSequence stack will be used.

    Returns
    -------
    list(SeedSequence)
        the requested SeedSequence objects
    """
    if parent is None:
        parent = _sseq[-1]
    return parent.spawn(n)


def current_rng():
    """Returns the RNG object currently in use by NIFTy.

    Returns
    -------
    Generator
        the current Generator object (top of the generator stack)
    """
    return _rng[-1]


def push_sseq(sseq):
    """Pushes a new SeedSequence object onto the SeedSequence stack.
    This also pushes a new Generator object built from the new SeedSequence
    onto the generator stack.

    Parameters
    ----------
    sseq : SeedSequence
        the SeedSequence object to be used from this point on

    Notes
    -----
    This function should only be used

    - if you only want to change the random seed once at the very beginning
      of a run, or

    - if the restoring of the previous state has to happen in a different
      Python function. In this case, please make sure that there is a matching
      call to :func:`pop_sseq` for every call to this function!

    In all other situations, it is highly recommended to use the
    :class:`Context` class for managing the RNG state.
    """
    _sseq.append(sseq)
    _rng.append(np.random.default_rng(_sseq[-1]))


def push_sseq_from_seed(seed):
    """Pushes a new SeedSequence object derived from an integer seed onto the
    SeedSequence stack.
    This also pushes a new Generator object built from the new SeedSequence
    onto the generator stack.

    Parameters
    ----------
    seed : int
        the seed from which the new SeedSequence will be built

    Notes
    -----
    This function should only be used

    - if you only want to change the random seed once at the very beginning
      of a run, or

    - if the restoring of the previous state has to happen in a different
      Python function. In this case, please make sure that there is a matching
      call to :func:`pop_sseq` for every call to this function!

    In all other situations, it is highly recommended to use the
    :class:`Context` class for managing the RNG state.
    """
    _sseq.append(np.random.SeedSequence(seed))
    _rng.append(np.random.default_rng(_sseq[-1]))


def pop_sseq():
    """Pops the top of the SeedSequence and generator stacks."""
    _sseq.pop()
    _rng.pop()


class Random(object):
    """Static helpers drawing random arrays from the generator on top of the
    generator stack (see :func:`current_rng`)."""

    @staticmethod
    def pm1(dtype, shape):
        """Draws values from {+1, -1}, or from {1, 1j, -1, -1j} if `dtype` is
        complex, with equal probabilities."""
        if np.issubdtype(dtype, np.complexfloating):
            x = np.array([1+0j, 0+1j, -1+0j, 0-1j], dtype=dtype)
            x = x[_rng[-1].integers(0, 4, size=shape)]
        else:
            x = 2*_rng[-1].integers(0, 2, size=shape)-1
        return x.astype(dtype, copy=False)

    @staticmethod
    def normal(dtype, shape, mean=0., std=1.):
        """Draws Gaussian random numbers with the given `mean` and `std`.

        For complex `dtype`, real and imaginary parts are drawn independently,
        each with standard deviation ``std*sqrt(0.5)``, so that the total
        variance is ``std**2``.
        """
        if not (np.issubdtype(dtype, np.floating) or
                np.issubdtype(dtype, np.complexfloating)):
            raise TypeError("dtype must be float or complex")
        if not np.isscalar(mean) or not np.isscalar(std):
            raise TypeError("mean and std must be scalars")
        if np.issubdtype(type(std), np.complexfloating):
            raise TypeError("std must not be complex")
        if ((not np.issubdtype(dtype, np.complexfloating)) and
                np.issubdtype(type(mean), np.complexfloating)):
            raise TypeError("mean must not be complex for a real result field")
        if np.issubdtype(dtype, np.complexfloating):
            x = np.empty(shape, dtype=dtype)
            x.real = _rng[-1].normal(mean.real, std*np.sqrt(0.5), shape)
            x.imag = _rng[-1].normal(mean.imag, std*np.sqrt(0.5), shape)
        else:
            x = _rng[-1].normal(mean, std, shape).astype(dtype, copy=False)
        return x

    @staticmethod
    def uniform(dtype, shape, low=0., high=1.):
        """Draws uniformly distributed random numbers.

        For floating-point `dtype`, values lie in [low, high); for complex
        `dtype`, real and imaginary parts are drawn independently from that
        interval. For integer `dtype`, values lie in [low, high], i.e. `high`
        is included.
        """
        if not np.isscalar(low) or not np.isscalar(high):
            raise TypeError("low and high must be scalars")
        if (np.issubdtype(type(low), np.complexfloating) or
                np.issubdtype(type(high), np.complexfloating)):
            raise TypeError("low and high must not be complex")
        if np.issubdtype(dtype, np.complexfloating):
            x = np.empty(shape, dtype=dtype)
            x.real = _rng[-1].uniform(low, high, shape)
            x.imag = _rng[-1].uniform(low, high, shape)
        elif np.issubdtype(dtype, np.integer):
            if not (np.issubdtype(type(low), np.integer) and
                    np.issubdtype(type(high), np.integer)):
                raise TypeError("low and high must be integer")
            x = _rng[-1].integers(low, high+1, shape)
        else:
            x = _rng[-1].uniform(low, high, shape)
        return x.astype(dtype, copy=False)


class Context(object):
    """Convenience class for easy management of the RNG state.

    Usage: ::

        with ift.random.Context(seed|sseq):
            code using the new RNG state

    At the end of the scope, the original RNG state will be restored
    automatically.

    Parameters
    ----------
    inp : int or numpy.random.SeedSequence
        The starting information for the new RNG state.
        If it is an integer, a new `SeedSequence` will be generated from it.
    """

    def __init__(self, inp):
        if not isinstance(inp, np.random.SeedSequence):
            inp = np.random.SeedSequence(inp)
        self._sseq = inp

    def __enter__(self):
        self._depth = len(_sseq)
        push_sseq(self._sseq)

    def __exit__(self, exc_type, exc_value, tb):
        pop_sseq()
        # A push inside the block without a matching pop changes the depth.
        if self._depth != len(_sseq):
            raise RuntimeError("inconsistent RNG usage detected")
        # Exceptions raised inside the block are not suppressed.
        return exc_type is None