# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from builtins import object
import numpy as np

class Random(object):
    """Static helpers that draw random numpy arrays.

    All draws go through the module-level ``np.random`` functions
    (``randint``, ``normal``, ``uniform``), i.e. numpy's global random
    state, so results are reproducible via ``np.random.seed``.
    """

    @staticmethod
    def pm1(dtype, shape):
        """Draw an array of random unit values.

        For a complex ``dtype`` every entry is picked uniformly from
        ``{1, 1j, -1, -1j}``; otherwise every entry is ``+1`` or ``-1``
        with equal probability.

        Parameters
        ----------
        dtype : numpy dtype or type
            Data type of the result.
        shape : int or tuple of int
            Shape of the result.

        Returns
        -------
        numpy.ndarray
            Array of the requested shape, cast to ``dtype``.
        """
        if np.issubdtype(dtype, np.complexfloating):
            # The four complex units; a random index selects one per entry.
            units = np.array([1+0j, 0+1j, -1+0j, 0-1j], dtype=dtype)
            x = units[np.random.randint(4, size=shape)]
        else:
            # Map randint's {0, 1} onto {-1, +1}.
            x = 2*np.random.randint(2, size=shape) - 1
        return x.astype(dtype, copy=False)

    @staticmethod
    def normal(dtype, shape, mean=0., std=1.):
        """Draw normally distributed random values.

        For a complex ``dtype`` the real and imaginary parts are drawn
        independently, each with standard deviation ``std/sqrt(2)``, so the
        total variance of each complex entry is ``std**2``. For any other
        ``dtype`` the values are drawn as floats and then cast with
        ``astype``.

        Parameters
        ----------
        dtype : numpy dtype or type
            Data type of the result.
        shape : int or tuple of int
            Shape of the result.
        mean : scalar, optional
            Mean of the distribution (default 0). It may be complex only
            if ``dtype`` is complex.
        std : real scalar, optional
            Standard deviation of the distribution (default 1).

        Returns
        -------
        numpy.ndarray
            Array of the requested shape and ``dtype``.

        Raises
        ------
        TypeError
            If ``mean`` or ``std`` is not a scalar, if ``std`` is complex,
            or if ``mean`` is complex while ``dtype`` is not.
        """
        if not np.isscalar(mean) or not np.isscalar(std):
            raise TypeError("mean and std must be scalars")
        if np.issubdtype(type(std), np.complexfloating):
            raise TypeError("std must not be complex")
        if ((not np.issubdtype(dtype, np.complexfloating)) and
                np.issubdtype(type(mean), np.complexfloating)):
            raise TypeError("mean must not be complex for a real result field")

        if np.issubdtype(dtype, np.complexfloating):
            x = np.empty(shape, dtype=dtype)
            # Split the variance evenly between the real and imaginary parts.
            sigma = std*np.sqrt(0.5)
            x.real = np.random.normal(mean.real, sigma, shape)
            x.imag = np.random.normal(mean.imag, sigma, shape)
        else:
            x = np.random.normal(mean, std, shape).astype(dtype, copy=False)
        return x

    @staticmethod
    def uniform(dtype, shape, low=0., high=1.):
        """Draw uniformly distributed random values.

        * integer ``dtype``: integers from the closed interval
          ``[low, high]``. ``low`` and ``high`` must then be integers, so
          the float defaults cannot be used.
        * complex ``dtype``: real and imaginary parts are drawn
          independently from ``[low, high)``.
        * any other ``dtype``: floats from ``[low, high)``, cast to
          ``dtype``.

        Parameters
        ----------
        dtype : numpy dtype or type
            Data type of the result.
        shape : int or tuple of int
            Shape of the result.
        low, high : real scalar, optional
            Bounds of the interval (defaults 0. and 1.).

        Returns
        -------
        numpy.ndarray
            Array of the requested shape, cast to ``dtype``.

        Raises
        ------
        TypeError
            If ``low``/``high`` are not scalars, are complex, or are not
            integers while ``dtype`` is an integer type.
        """
        if not np.isscalar(low) or not np.isscalar(high):
            raise TypeError("low and high must be scalars")
        if (np.issubdtype(type(low), np.complexfloating) or
                np.issubdtype(type(high), np.complexfloating)):
            raise TypeError("low and high must not be complex")

        if np.issubdtype(dtype, np.complexfloating):
            x = np.empty(shape, dtype=dtype)
            x.real = np.random.uniform(low, high, shape)
            x.imag = np.random.uniform(low, high, shape)
        elif np.issubdtype(dtype, np.integer):
            if not (np.issubdtype(type(low), np.integer) and
                    np.issubdtype(type(high), np.integer)):
                raise TypeError("low and high must be integer")
            # randint excludes its upper bound; +1 makes `high` inclusive.
            x = np.random.randint(low, high+1, shape)
        else:
            x = np.random.uniform(low, high, shape)
        return x.astype(dtype, copy=False)