Skip to content
Snippets Groups Projects
Commit 7c69cf19 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more fixes

parent 2bc45272
No related branches found
No related tags found
No related merge requests found
Pipeline #
...@@ -32,12 +32,15 @@ class Random(object): ...@@ -32,12 +32,15 @@ class Random(object):
@staticmethod @staticmethod
def normal(dtype, shape, mean=0., std=1.): def normal(dtype, shape, mean=0., std=1.):
if not (np.issubdtype(dtype, np.floating) or
np.issubdtype(dtype, np.complexfloating)):
raise TypeError("dtype must be float or complex")
if not np.isscalar(mean) or not np.isscalar(std): if not np.isscalar(mean) or not np.isscalar(std):
raise TypeError("mean and std must be scalars") raise TypeError("mean and std must be scalars")
if np.issubdtype(type(std), np.complexfloating): if np.issubdtype(type(std), np.complexfloating):
raise TypeError("std must not be complex") raise TypeError("std must not be complex")
if ((not np.issubdtype(dtype, np.complexfloating)) and if ((not np.issubdtype(dtype, np.complexfloating)) and
np.issubdtype(type(mean), np.complexfloating)): np.issubdtype(type(mean), np.complexfloating)):
raise TypeError("mean must not be complex for a real result field") raise TypeError("mean must not be complex for a real result field")
if np.issubdtype(dtype, np.complexfloating): if np.issubdtype(dtype, np.complexfloating):
x = np.empty(shape, dtype=dtype) x = np.empty(shape, dtype=dtype)
...@@ -51,7 +54,7 @@ class Random(object): ...@@ -51,7 +54,7 @@ class Random(object):
def uniform(dtype, shape, low=0., high=1.): def uniform(dtype, shape, low=0., high=1.):
if not np.isscalar(low) or not np.isscalar(high): if not np.isscalar(low) or not np.isscalar(high):
raise TypeError("low and high must be scalars") raise TypeError("low and high must be scalars")
if (np.issubdtype(type(low), np.complexfloating) or if (np.issubdtype(type(low), np.complexfloating) or
np.issubdtype(type(high), np.complexfloating)): np.issubdtype(type(high), np.complexfloating)):
raise TypeError("low and high must not be complex") raise TypeError("low and high must not be complex")
if np.issubdtype(dtype, np.complexfloating): if np.issubdtype(dtype, np.complexfloating):
...@@ -59,7 +62,7 @@ class Random(object): ...@@ -59,7 +62,7 @@ class Random(object):
x.real = np.random.uniform(low, high, shape) x.real = np.random.uniform(low, high, shape)
x.imag = np.random.uniform(low, high, shape) x.imag = np.random.uniform(low, high, shape)
elif np.issubdtype(dtype, np.integer): elif np.issubdtype(dtype, np.integer):
if not (np.issubdtype(type(low), np.integer) and if not (np.issubdtype(type(low), np.integer) and
np.issubdtype(type(high), np.integer)): np.issubdtype(type(high), np.integer)):
raise TypeError("low and high must be integer") raise TypeError("low and high must be integer")
x = np.random.randint(low, high+1, shape) x = np.random.randint(low, high+1, shape)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment