Commit 06639bc9 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

add a context manager for the random module

parent e837fc03
......@@ -231,3 +231,16 @@ class Random(object):
else:
x = _rng[-1].uniform(low, high, shape)
return x.astype(dtype, copy=False)
class Context(object):
def __init__(self, inp):
if not isinstance(inp, np.random.SeedSequence):
inp = np.random.SeedSequence(inp)
self._sseq = inp
def __enter__(self):
push_sseq(self._sseq)
def __exit__(self, exc_type, exc_value, tb):
return exc_type is None
......@@ -21,11 +21,10 @@ import nifty6 as ift
def test_rand1():
ift.random.push_sseq_from_seed(31)
a = ift.random.current_rng().integers(0,1000000000)
ift.random.pop_sseq()
ift.random.push_sseq_from_seed(31)
b = ift.random.current_rng().integers(0,1000000000)
with ift.random.Context(31):
a = ift.random.current_rng().integers(0,1000000000)
with ift.random.Context(31):
b = ift.random.current_rng().integers(0,1000000000)
np.testing.assert_equal(a,b)
......@@ -43,21 +42,17 @@ def test_rand2():
def test_rand3():
ift.random.push_sseq_from_seed(31)
sseq = ift.random.spawn_sseq(10)
ift.random.push_sseq(sseq[2])
a = ift.random.current_rng().integers(0,1000000000)
ift.random.pop_sseq()
ift.random.pop_sseq()
ift.random.push_sseq_from_seed(31)
sseq = ift.random.spawn_sseq(1)
sseq = ift.random.spawn_sseq(1)
sseq = ift.random.spawn_sseq(1)
ift.random.push_sseq(sseq[0])
b = ift.random.current_rng().integers(0,1000000000)
ift.random.pop_sseq()
with ift.random.Context(31):
sseq = ift.random.spawn_sseq(10)
with ift.random.Context(sseq[2]):
a = ift.random.current_rng().integers(0,1000000000)
with ift.random.Context(31):
sseq = ift.random.spawn_sseq(1)
sseq = ift.random.spawn_sseq(1)
sseq = ift.random.spawn_sseq(1)
with ift.random.Context(sseq[0]):
b = ift.random.current_rng().integers(0,1000000000)
np.testing.assert_equal(a,b)
ift.random.pop_sseq()
def test_rand4():
......@@ -70,6 +65,14 @@ def test_rand4():
np.testing.assert_equal(a,b)
def test_rand4b():
with ift.random.Context(31):
a = ift.random.current_rng().integers(0,1000000000)
with ift.random.Context(31):
b = ift.random.current_rng().integers(0,1000000000)
np.testing.assert_equal(a,b)
def test_rand5():
ift.random.push_sseq_from_seed(31)
a = ift.random.current_rng().integers(0,1000000000)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment