From d82f28b2425a537be50facd1279d656b53236a3b Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Thu, 19 Mar 2020 11:46:10 +0100
Subject: [PATCH] demo implementation

---
 demos/getting_started_1.py |   2 +-
 nifty6/__init__.py         |   2 +
 nifty6/field.py            |   4 +-
 nifty6/random.py           | 111 +++++++++++++++++++++----------------
 4 files changed, 68 insertions(+), 51 deletions(-)

diff --git a/demos/getting_started_1.py b/demos/getting_started_1.py
index 886a4ceaa..0054f3f40 100644
--- a/demos/getting_started_1.py
+++ b/demos/getting_started_1.py
@@ -46,7 +46,7 @@ def make_random_mask():
 
 
 if __name__ == '__main__':
-    np.random.seed(42)
+    ift.random.init(42)
 
     # Choose space on which the signal field is defined
     if len(sys.argv) == 2:
diff --git a/nifty6/__init__.py b/nifty6/__init__.py
index 7ba7583e9..e2e8ec76b 100644
--- a/nifty6/__init__.py
+++ b/nifty6/__init__.py
@@ -1,5 +1,7 @@
 from .version import __version__
 
+from . import random
+
 from .domains.domain import Domain
 from .domains.structured_domain import StructuredDomain
 from .domains.unstructured_domain import UnstructuredDomain
diff --git a/nifty6/field.py b/nifty6/field.py
index 306dad503..d325ff8ee 100644
--- a/nifty6/field.py
+++ b/nifty6/field.py
@@ -140,9 +140,9 @@ class Field(object):
         Field
             The newly created Field.
         """
-        from .random import Random
+        from . import random
         domain = DomainTuple.make(domain)
-        generator_function = getattr(Random, random_type)
+        generator_function = getattr(random, random_type)
         arr = generator_function(dtype=dtype, shape=domain.shape, **kwargs)
         return Field(domain, arr)
 
diff --git a/nifty6/random.py b/nifty6/random.py
index 673bf6596..6c4a71adf 100644
--- a/nifty6/random.py
+++ b/nifty6/random.py
@@ -11,59 +11,74 @@
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 #
-# Copyright(C) 2013-2019 Max-Planck-Society
+# Copyright(C) 2013-2020 Max-Planck-Society
 #
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
 
 import numpy as np
 
+_initialized = False
 
-class Random(object):
-    @staticmethod
-    def pm1(dtype, shape):
-        if np.issubdtype(dtype, np.complexfloating):
-            x = np.array([1+0j, 0+1j, -1+0j, 0-1j], dtype=dtype)
-            x = x[np.random.randint(4, size=shape)]
-        else:
-            x = 2*np.random.randint(2, size=shape) - 1
-        return x.astype(dtype, copy=False)
+def init(seed):
+    global _initialized
+    if _initialized:
+        print("WARNING: re-intializing random generator")
+        np.random.seed(seed)
+    else:
+        _initialized = True
+        np.random.seed(seed)
 
-    @staticmethod
-    def normal(dtype, shape, mean=0., std=1.):
-        if not (np.issubdtype(dtype, np.floating) or
-                np.issubdtype(dtype, np.complexfloating)):
-            raise TypeError("dtype must be float or complex")
-        if not np.isscalar(mean) or not np.isscalar(std):
-            raise TypeError("mean and std must be scalars")
-        if np.issubdtype(type(std), np.complexfloating):
-            raise TypeError("std must not be complex")
-        if ((not np.issubdtype(dtype, np.complexfloating)) and
-                np.issubdtype(type(mean), np.complexfloating)):
-            raise TypeError("mean must not be complex for a real result field")
-        if np.issubdtype(dtype, np.complexfloating):
-            x = np.empty(shape, dtype=dtype)
-            x.real = np.random.normal(mean.real, std*np.sqrt(0.5), shape)
-            x.imag = np.random.normal(mean.imag, std*np.sqrt(0.5), shape)
-        else:
-            x = np.random.normal(mean, std, shape).astype(dtype, copy=False)
-        return x
+def pm1(dtype, shape):
+    global _initialized
+    if not _initialized:
+        raise RuntimeError("RNG not initialized")
+    if np.issubdtype(dtype, np.complexfloating):
+        x = np.array([1+0j, 0+1j, -1+0j, 0-1j], dtype=dtype)
+        x = x[np.random.randint(4, size=shape)]
+    else:
+        x = 2*np.random.randint(2, size=shape) - 1
+    return x.astype(dtype, copy=False)
 
-    @staticmethod
-    def uniform(dtype, shape, low=0., high=1.):
-        if not np.isscalar(low) or not np.isscalar(high):
-            raise TypeError("low and high must be scalars")
-        if (np.issubdtype(type(low), np.complexfloating) or
-                np.issubdtype(type(high), np.complexfloating)):
-            raise TypeError("low and high must not be complex")
-        if np.issubdtype(dtype, np.complexfloating):
-            x = np.empty(shape, dtype=dtype)
-            x.real = np.random.uniform(low, high, shape)
-            x.imag = np.random.uniform(low, high, shape)
-        elif np.issubdtype(dtype, np.integer):
-            if not (np.issubdtype(type(low), np.integer) and
-                    np.issubdtype(type(high), np.integer)):
-                raise TypeError("low and high must be integer")
-            x = np.random.randint(low, high+1, shape)
-        else:
-            x = np.random.uniform(low, high, shape)
-        return x.astype(dtype, copy=False)
+def normal(dtype, shape, mean=0., std=1.):
+    global _initialized
+    if not _initialized:
+        raise RuntimeError("RNG not initialized")
+    if not (np.issubdtype(dtype, np.floating) or
+            np.issubdtype(dtype, np.complexfloating)):
+        raise TypeError("dtype must be float or complex")
+    if not np.isscalar(mean) or not np.isscalar(std):
+        raise TypeError("mean and std must be scalars")
+    if np.issubdtype(type(std), np.complexfloating):
+        raise TypeError("std must not be complex")
+    if ((not np.issubdtype(dtype, np.complexfloating)) and
+            np.issubdtype(type(mean), np.complexfloating)):
+        raise TypeError("mean must not be complex for a real result field")
+    if np.issubdtype(dtype, np.complexfloating):
+        x = np.empty(shape, dtype=dtype)
+        x.real = np.random.normal(mean.real, std*np.sqrt(0.5), shape)
+        x.imag = np.random.normal(mean.imag, std*np.sqrt(0.5), shape)
+    else:
+        x = np.random.normal(mean, std, shape).astype(dtype, copy=False)
+    return x
+
+def uniform(dtype, shape, low=0., high=1.):
+    global _initialized
+    if not _initialized:
+        raise RuntimeError("RNG not initialized")
+    if not np.isscalar(low) or not np.isscalar(high):
+        raise TypeError("low and high must be scalars")
+    if (np.issubdtype(type(low), np.complexfloating) or
+            np.issubdtype(type(high), np.complexfloating)):
+        raise TypeError("low and high must not be complex")
+    if np.issubdtype(dtype, np.complexfloating):
+        x = np.empty(shape, dtype=dtype)
+        x.real = np.random.uniform(low, high, shape)
+        x.imag = np.random.uniform(low, high, shape)
+    elif np.issubdtype(dtype, np.integer):
+        if not (np.issubdtype(type(low), np.integer) and
+                np.issubdtype(type(high), np.integer)):
+            raise TypeError("low and high must be integer")
+        x = np.random.randint(low, high+1, shape)
+    else:
+        x = np.random.uniform(low, high, shape)
+    return x.astype(dtype, copy=False)
-- 
GitLab