diff --git a/demos/getting_started_1.py b/demos/getting_started_1.py index 886a4ceaa7362f7ae73d90a92221536e45e359d1..0054f3f40f52bf8aeb3e46d8922bf73090224bfa 100644 --- a/demos/getting_started_1.py +++ b/demos/getting_started_1.py @@ -46,7 +46,7 @@ def make_random_mask(): if __name__ == '__main__': - np.random.seed(42) + ift.random.init(42) # Choose space on which the signal field is defined if len(sys.argv) == 2: diff --git a/nifty6/__init__.py b/nifty6/__init__.py index 7ba7583e984d35422c16ca02ea48fdcfbeef6207..e2e8ec76b78d439354f3b6ef83a1b96ec2d21157 100644 --- a/nifty6/__init__.py +++ b/nifty6/__init__.py @@ -1,5 +1,7 @@ from .version import __version__ +from . import random + from .domains.domain import Domain from .domains.structured_domain import StructuredDomain from .domains.unstructured_domain import UnstructuredDomain diff --git a/nifty6/field.py b/nifty6/field.py index 306dad50301323e290411e81c0e88bb518b3ad73..d325ff8ee850529cfd1575050f42dbdfe9a13174 100644 --- a/nifty6/field.py +++ b/nifty6/field.py @@ -140,9 +140,9 @@ class Field(object): Field The newly created Field. """ - from .random import Random + from . import random domain = DomainTuple.make(domain) - generator_function = getattr(Random, random_type) + generator_function = getattr(random, random_type) arr = generator_function(dtype=dtype, shape=domain.shape, **kwargs) return Field(domain, arr) diff --git a/nifty6/random.py b/nifty6/random.py index 673bf6596444d826f149520a26759f8dd423a603..6c4a71adfd632288d77cfdd7b3b110cd2e61197b 100644 --- a/nifty6/random.py +++ b/nifty6/random.py @@ -11,59 +11,74 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. # -# Copyright(C) 2013-2019 Max-Planck-Society +# Copyright(C) 2013-2020 Max-Planck-Society # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. import numpy as np +_initialized = False -class Random(object): - @staticmethod - def pm1(dtype, shape): - if np.issubdtype(dtype, np.complexfloating): - x = np.array([1+0j, 0+1j, -1+0j, 0-1j], dtype=dtype) - x = x[np.random.randint(4, size=shape)] - else: - x = 2*np.random.randint(2, size=shape) - 1 - return x.astype(dtype, copy=False) +def init(seed): + global _initialized + if _initialized: + print("WARNING: re-intializing random generator") + np.random.seed(seed) + else: + _initialized = True + np.random.seed(seed) - @staticmethod - def normal(dtype, shape, mean=0., std=1.): - if not (np.issubdtype(dtype, np.floating) or - np.issubdtype(dtype, np.complexfloating)): - raise TypeError("dtype must be float or complex") - if not np.isscalar(mean) or not np.isscalar(std): - raise TypeError("mean and std must be scalars") - if np.issubdtype(type(std), np.complexfloating): - raise TypeError("std must not be complex") - if ((not np.issubdtype(dtype, np.complexfloating)) and - np.issubdtype(type(mean), np.complexfloating)): - raise TypeError("mean must not be complex for a real result field") - if np.issubdtype(dtype, np.complexfloating): - x = np.empty(shape, dtype=dtype) - x.real = np.random.normal(mean.real, std*np.sqrt(0.5), shape) - x.imag = np.random.normal(mean.imag, std*np.sqrt(0.5), shape) - else: - x = np.random.normal(mean, std, shape).astype(dtype, copy=False) - return x +def pm1(dtype, shape): + global _initialized + if not _initialized: + raise RuntimeError("RNG not initialized") + if np.issubdtype(dtype, np.complexfloating): + x = np.array([1+0j, 0+1j, -1+0j, 0-1j], dtype=dtype) + x = x[np.random.randint(4, size=shape)] + else: + x = 2*np.random.randint(2, size=shape) - 1 + return x.astype(dtype, copy=False) - @staticmethod - def uniform(dtype, shape, low=0., high=1.): - if not np.isscalar(low) or not np.isscalar(high): - raise TypeError("low and high must be scalars") - if (np.issubdtype(type(low), np.complexfloating) or - np.issubdtype(type(high), np.complexfloating)): - raise TypeError("low and high must not be complex") - if np.issubdtype(dtype, np.complexfloating): - x = np.empty(shape, dtype=dtype) - x.real = np.random.uniform(low, high, shape) - x.imag = np.random.uniform(low, high, shape) - elif np.issubdtype(dtype, np.integer): - if not (np.issubdtype(type(low), np.integer) and - np.issubdtype(type(high), np.integer)): - raise TypeError("low and high must be integer") - x = np.random.randint(low, high+1, shape) - else: - x = np.random.uniform(low, high, shape) - return x.astype(dtype, copy=False) +def normal(dtype, shape, mean=0., std=1.): + global _initialized + if not _initialized: + raise RuntimeError("RNG not initialized") + if not (np.issubdtype(dtype, np.floating) or + np.issubdtype(dtype, np.complexfloating)): + raise TypeError("dtype must be float or complex") + if not np.isscalar(mean) or not np.isscalar(std): + raise TypeError("mean and std must be scalars") + if np.issubdtype(type(std), np.complexfloating): + raise TypeError("std must not be complex") + if ((not np.issubdtype(dtype, np.complexfloating)) and + np.issubdtype(type(mean), np.complexfloating)): + raise TypeError("mean must not be complex for a real result field") + if np.issubdtype(dtype, np.complexfloating): + x = np.empty(shape, dtype=dtype) + x.real = np.random.normal(mean.real, std*np.sqrt(0.5), shape) + x.imag = np.random.normal(mean.imag, std*np.sqrt(0.5), shape) + else: + x = np.random.normal(mean, std, shape).astype(dtype, copy=False) + return x + +def uniform(dtype, shape, low=0., high=1.): + global _initialized + if not _initialized: + raise RuntimeError("RNG not initialized") + if not np.isscalar(low) or not np.isscalar(high): + raise TypeError("low and high must be scalars") + if (np.issubdtype(type(low), np.complexfloating) or + np.issubdtype(type(high), np.complexfloating)): + raise TypeError("low and high must not be complex") + if np.issubdtype(dtype, np.complexfloating): + x = np.empty(shape, dtype=dtype) + x.real = np.random.uniform(low, high, shape) + x.imag = np.random.uniform(low, high, shape) + elif np.issubdtype(dtype, np.integer): + if not (np.issubdtype(type(low), np.integer) and + np.issubdtype(type(high), np.integer)): + raise TypeError("low and high must be integer") + x = np.random.randint(low, high+1, shape) + else: + x = np.random.uniform(low, high, shape) + return x.astype(dtype, copy=False)