From d82f28b2425a537be50facd1279d656b53236a3b Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Thu, 19 Mar 2020 11:46:10 +0100 Subject: [PATCH] demo implementation --- demos/getting_started_1.py | 2 +- nifty6/__init__.py | 2 + nifty6/field.py | 4 +- nifty6/random.py | 111 +++++++++++++++++++++---------------- 4 files changed, 68 insertions(+), 51 deletions(-) diff --git a/demos/getting_started_1.py b/demos/getting_started_1.py index 886a4ceaa..0054f3f40 100644 --- a/demos/getting_started_1.py +++ b/demos/getting_started_1.py @@ -46,7 +46,7 @@ def make_random_mask(): if __name__ == '__main__': - np.random.seed(42) + ift.random.init(42) # Choose space on which the signal field is defined if len(sys.argv) == 2: diff --git a/nifty6/__init__.py b/nifty6/__init__.py index 7ba7583e9..e2e8ec76b 100644 --- a/nifty6/__init__.py +++ b/nifty6/__init__.py @@ -1,5 +1,7 @@ from .version import __version__ +from . import random + from .domains.domain import Domain from .domains.structured_domain import StructuredDomain from .domains.unstructured_domain import UnstructuredDomain diff --git a/nifty6/field.py b/nifty6/field.py index 306dad503..d325ff8ee 100644 --- a/nifty6/field.py +++ b/nifty6/field.py @@ -140,9 +140,9 @@ class Field(object): Field The newly created Field. """ - from .random import Random + from . import random domain = DomainTuple.make(domain) - generator_function = getattr(Random, random_type) + generator_function = getattr(random, random_type) arr = generator_function(dtype=dtype, shape=domain.shape, **kwargs) return Field(domain, arr) diff --git a/nifty6/random.py b/nifty6/random.py index 673bf6596..6c4a71adf 100644 --- a/nifty6/random.py +++ b/nifty6/random.py @@ -11,59 +11,74 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. # -# Copyright(C) 2013-2019 Max-Planck-Society +# Copyright(C) 2013-2020 Max-Planck-Society # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. import numpy as np +_initialized = False -class Random(object): - @staticmethod - def pm1(dtype, shape): - if np.issubdtype(dtype, np.complexfloating): - x = np.array([1+0j, 0+1j, -1+0j, 0-1j], dtype=dtype) - x = x[np.random.randint(4, size=shape)] - else: - x = 2*np.random.randint(2, size=shape) - 1 - return x.astype(dtype, copy=False) +def init(seed): + global _initialized + if _initialized: + print("WARNING: re-intializing random generator") + np.random.seed(seed) + else: + _initialized = True + np.random.seed(seed) - @staticmethod - def normal(dtype, shape, mean=0., std=1.): - if not (np.issubdtype(dtype, np.floating) or - np.issubdtype(dtype, np.complexfloating)): - raise TypeError("dtype must be float or complex") - if not np.isscalar(mean) or not np.isscalar(std): - raise TypeError("mean and std must be scalars") - if np.issubdtype(type(std), np.complexfloating): - raise TypeError("std must not be complex") - if ((not np.issubdtype(dtype, np.complexfloating)) and - np.issubdtype(type(mean), np.complexfloating)): - raise TypeError("mean must not be complex for a real result field") - if np.issubdtype(dtype, np.complexfloating): - x = np.empty(shape, dtype=dtype) - x.real = np.random.normal(mean.real, std*np.sqrt(0.5), shape) - x.imag = np.random.normal(mean.imag, std*np.sqrt(0.5), shape) - else: - x = np.random.normal(mean, std, shape).astype(dtype, copy=False) - return x +def pm1(dtype, shape): + global _initialized + if not _initialized: + raise RuntimeError("RNG not initialized") + if np.issubdtype(dtype, np.complexfloating): + x = np.array([1+0j, 0+1j, -1+0j, 0-1j], dtype=dtype) + x = x[np.random.randint(4, size=shape)] + else: + x = 2*np.random.randint(2, size=shape) - 1 + return x.astype(dtype, copy=False) - @staticmethod - def uniform(dtype, shape, low=0., high=1.): - if not np.isscalar(low) or not np.isscalar(high): - raise TypeError("low and high must be scalars") - if (np.issubdtype(type(low), np.complexfloating) or - np.issubdtype(type(high), np.complexfloating)): - raise TypeError("low and high must not be complex") - if np.issubdtype(dtype, np.complexfloating): - x = np.empty(shape, dtype=dtype) - x.real = np.random.uniform(low, high, shape) - x.imag = np.random.uniform(low, high, shape) - elif np.issubdtype(dtype, np.integer): - if not (np.issubdtype(type(low), np.integer) and - np.issubdtype(type(high), np.integer)): - raise TypeError("low and high must be integer") - x = np.random.randint(low, high+1, shape) - else: - x = np.random.uniform(low, high, shape) - return x.astype(dtype, copy=False) +def normal(dtype, shape, mean=0., std=1.): + global _initialized + if not _initialized: + raise RuntimeError("RNG not initialized") + if not (np.issubdtype(dtype, np.floating) or + np.issubdtype(dtype, np.complexfloating)): + raise TypeError("dtype must be float or complex") + if not np.isscalar(mean) or not np.isscalar(std): + raise TypeError("mean and std must be scalars") + if np.issubdtype(type(std), np.complexfloating): + raise TypeError("std must not be complex") + if ((not np.issubdtype(dtype, np.complexfloating)) and + np.issubdtype(type(mean), np.complexfloating)): + raise TypeError("mean must not be complex for a real result field") + if np.issubdtype(dtype, np.complexfloating): + x = np.empty(shape, dtype=dtype) + x.real = np.random.normal(mean.real, std*np.sqrt(0.5), shape) + x.imag = np.random.normal(mean.imag, std*np.sqrt(0.5), shape) + else: + x = np.random.normal(mean, std, shape).astype(dtype, copy=False) + return x + +def uniform(dtype, shape, low=0., high=1.): + global _initialized + if not _initialized: + raise RuntimeError("RNG not initialized") + if not np.isscalar(low) or not np.isscalar(high): + raise TypeError("low and high must be scalars") + if (np.issubdtype(type(low), np.complexfloating) or + np.issubdtype(type(high), np.complexfloating)): + raise TypeError("low and high must not be complex") + if np.issubdtype(dtype, np.complexfloating): + x = np.empty(shape, dtype=dtype) + x.real = np.random.uniform(low, high, shape) + x.imag = np.random.uniform(low, high, shape) + elif np.issubdtype(dtype, np.integer): + if not (np.issubdtype(type(low), np.integer) and + np.issubdtype(type(high), np.integer)): + raise TypeError("low and high must be integer") + x = np.random.randint(low, high+1, shape) + else: + x = np.random.uniform(low, high, shape) + return x.astype(dtype, copy=False) -- GitLab