Commit 8a3b5d58 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'randomcontext' into 'NIFTy_6'

fix and stricter tests

See merge request !438
parents 9961a7f7 97987ea3
Pipeline #71976 passed with stages
in 19 minutes and 55 seconds
......@@ -149,8 +149,14 @@ def push_sseq(sseq):
Notes
-----
Make sure that every call to `push_sseq` has a matching call to
`pop_sseq`, otherwise the module's internal stack will grow indefinitely!
This function should only be used
- if you only want to change the random seed once at the very beginning
of a run, or
- if the restoring of the previous state has to happen in a different
Python function. In this case, please make sure that there is a matching
call to `pop_sseq` for every call to this function!
In all other situations, it is highly recommended to use the `Context`
class for managing the RNG state.
"""
_sseq.append(sseq)
_rng.append(np.random.default_rng(_sseq[-1]))
......@@ -169,8 +175,14 @@ def push_sseq_from_seed(seed):
Notes
-----
Make sure that every call to `push_sseq_from_seed` has a matching call to
`pop_sseq`, otherwise the module's internal stack will grow indefinitely!
This function should only be used
- if you only want to change the random seed once at the very beginning
of a run, or
- if the restoring of the previous state has to happen in a different
Python function. In this case, please make sure that there is a matching
call to `pop_sseq` for every call to this function!
In all other situations, it is highly recommended to use the `Context`
class for managing the RNG state.
"""
_sseq.append(np.random.SeedSequence(seed))
_rng.append(np.random.default_rng(_sseq[-1]))
......@@ -234,13 +246,34 @@ class Random(object):
class Context(object):
"""Convenience class for easy management of the RNG state.
Usage:
```
with ift.random.Context(seed|sseq):
<code using the new RNG state>
```
At the end of the scope, the original RNG state will be restored
automatically.
Parameters
----------
inp : int or numpy.random.SeedSequence
The starting information for the new RNG state.
If it is an integer, a new `SeedSequence` will be generated from it.
"""
def __init__(self, inp):
if not isinstance(inp, np.random.SeedSequence):
inp = np.random.SeedSequence(inp)
self._sseq = inp
def __enter__(self):
self._depth = len(_sseq)
push_sseq(self._sseq)
def __exit__(self, exc_type, exc_value, tb):
pop_sseq()
if self._depth != len(_sseq):
raise RuntimeError("inconsistent RNG usage detected")
return exc_type is None
......@@ -20,11 +20,17 @@ import numpy as np
import nifty6 as ift
def check_state_back_to_orig():
np.testing.assert_equal(len(ift.random._rng),1)
np.testing.assert_equal(len(ift.random._sseq),1)
def test_rand1():
with ift.random.Context(31):
a = ift.random.current_rng().integers(0,1000000000)
with ift.random.Context(31):
b = ift.random.current_rng().integers(0,1000000000)
check_state_back_to_orig()
np.testing.assert_equal(a,b)
......@@ -34,6 +40,7 @@ def test_rand2():
a = ift.random.current_rng().integers(0,1000000000)
with ift.random.Context(sseq[2]):
b = ift.random.current_rng().integers(0,1000000000)
check_state_back_to_orig()
np.testing.assert_equal(a,b)
......@@ -48,6 +55,7 @@ def test_rand3():
sseq = ift.random.spawn_sseq(1)
with ift.random.Context(sseq[0]):
b = ift.random.current_rng().integers(0,1000000000)
check_state_back_to_orig()
np.testing.assert_equal(a,b)
......@@ -68,9 +76,23 @@ def test_rand5():
ift.random.pop_sseq()
d = ift.random.current_rng().integers(0,1000000000)
ift.random.pop_sseq()
check_state_back_to_orig()
np.testing.assert_equal(a,b)
np.testing.assert_equal(c,d)
def test_rand5b():
with ift.random.Context(31):
a = ift.random.current_rng().integers(0,1000000000)
with ift.random.Context(31):
b = ift.random.current_rng().integers(0,1000000000)
c = ift.random.current_rng().integers(0,1000000000)
d = ift.random.current_rng().integers(0,1000000000)
check_state_back_to_orig()
np.testing.assert_equal(a,b)
np.testing.assert_equal(c,d)
def test_rand6():
ift.random.push_sseq_from_seed(31)
state = ift.random.getState()
......@@ -79,4 +101,14 @@ def test_rand6():
b = ift.random.current_rng().integers(0,1000000000)
np.testing.assert_equal(a,b)
ift.random.pop_sseq()
check_state_back_to_orig()
def test_rand6b():
with ift.random.Context(31):
state = ift.random.getState()
a = ift.random.current_rng().integers(0,1000000000)
ift.random.setState(state)
b = ift.random.current_rng().integers(0,1000000000)
np.testing.assert_equal(a,b)
check_state_back_to_orig()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment