Skip to content
Snippets Groups Projects

fix and stricter tests

Merged Martin Reinecke requested to merge randomcontext into NIFTy_6
Files
2
+ 37
4
@@ -149,8 +149,14 @@ def push_sseq(sseq):
Notes
-----
Make sure that every call to `push_sseq` has a matching call to
`pop_sseq`, otherwise the module's internal stack will grow indefinitely!
This function should only be used
- if you only want to change the random seed once at the very beginning
of a run, or
- if the restoring of the previous state has to happen in a different
Python function. In this case, please make sure that there is a matching
call to `pop_sseq` for every call to this function!
In all other situations, it is highly recommended to use the `Context`
class for managing the RNG state.
"""
_sseq.append(sseq)
_rng.append(np.random.default_rng(_sseq[-1]))
@@ -169,8 +175,14 @@ def push_sseq_from_seed(seed):
Notes
-----
Make sure that every call to `push_sseq_from_seed` has a matching call to
`pop_sseq`, otherwise the module's internal stack will grow indefinitely!
This function should only be used
- if you only want to change the random seed once at the very beginning
of a run, or
- if the restoring of the previous state has to happen in a different
Python function. In this case, please make sure that there is a matching
call to `pop_sseq` for every call to this function!
In all other situations, it is highly recommended to use the `Context`
class for managing the RNG state.
"""
_sseq.append(np.random.SeedSequence(seed))
_rng.append(np.random.default_rng(_sseq[-1]))
@@ -234,13 +246,34 @@ class Random(object):
class Context(object):
"""Convenience class for easy management of the RNG state.
Usage:
```
with ift.random.Context(seed|sseq):
<code using the new RNG state>
```
At the end of the scope, the original RNG state will be restored
automatically.
Parameters
----------
inp : int or numpy.random.SeedSequence
The starting information for the new RNG state.
If it is an integer, a new `SeedSequence` will be generated from it.
"""
def __init__(self, inp):
if not isinstance(inp, np.random.SeedSequence):
inp = np.random.SeedSequence(inp)
self._sseq = inp
def __enter__(self):
self._depth = len(_sseq)
push_sseq(self._sseq)
def __exit__(self, exc_type, exc_value, tb):
pop_sseq()
if self._depth != len(_sseq):
raise RuntimeError("inconsistent RNG usage detected")
return exc_type is None
Loading