Skip to content
Snippets Groups Projects
Commit a0fa21d9 authored by Philipp Arras's avatar Philipp Arras
Browse files

Merge branch 'randomstate' into 'NIFTy_6'

allow reading and restoring the state of the random module

See merge request !435
parents ece9c551 3a3cd86c
Branches
Tags
1 merge request!435allow reading and restoring the state of the random module
Pipeline #71903 passed
...@@ -79,6 +79,31 @@ _sseq = [np.random.SeedSequence(42)] ...@@ -79,6 +79,31 @@ _sseq = [np.random.SeedSequence(42)]
_rng = [np.random.default_rng(_sseq[-1])] _rng = [np.random.default_rng(_sseq[-1])]
def getState():
"""Returns the full internal state of the module. Intended for pickling.
Returns
-------
state : unspecified
"""
import pickle
return pickle.dumps((_sseq, _rng))
def setState(state):
"""Restores the full internal state of the module. Intended for unpickling.
Parameters
----------
state : unspecified
Result of an earlier call to `getState`.
"""
import pickle
global _sseq, _rng
_sseq, _rng = pickle.loads(state)
def spawn_sseq(n, parent=None): def spawn_sseq(n, parent=None):
"""Returns a list of `n` SeedSequence objects which are children of `parent` """Returns a list of `n` SeedSequence objects which are children of `parent`
......
...@@ -81,3 +81,12 @@ def test_rand5(): ...@@ -81,3 +81,12 @@ def test_rand5():
ift.random.pop_sseq() ift.random.pop_sseq()
np.testing.assert_equal(a,b) np.testing.assert_equal(a,b)
np.testing.assert_equal(c,d) np.testing.assert_equal(c,d)
def test_rand6():
ift.random.push_sseq_from_seed(31)
state = ift.random.getState()
a = ift.random.current_rng().integers(0,1000000000)
ift.random.setState(state)
b = ift.random.current_rng().integers(0,1000000000)
np.testing.assert_equal(a,b)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment