# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from builtins import object
import numpy as np

class Random(object):
    """Collection of static helpers that fill arrays with random numbers.

    All samples come from NumPy's global random state (``np.random``), so
    seeding that state makes the output of these helpers reproducible.
    """

    @staticmethod
    def pm1(dtype, shape):
        """Draw an array of random "unit" values.

        Parameters
        ----------
        dtype : numpy dtype or type
            Data type of the result.
        shape : int or tuple of int
            Shape of the result.

        Returns
        -------
        numpy.ndarray
            For a complex `dtype`, each entry is one of 1, 1j, -1, -1j with
            equal probability; otherwise each entry is -1 or +1 with equal
            probability, converted to `dtype`.
        """
        if np.issubdtype(dtype, np.complexfloating):
            units = np.array([1+0j, 0+1j, -1+0j, 0-1j], dtype=dtype)
            draw = units[np.random.randint(4, size=shape)]
        else:
            # Map {0, 1} onto {-1, +1}.
            draw = np.random.randint(2, size=shape)*2 - 1
        return draw.astype(dtype, copy=False)

    @staticmethod
    def normal(dtype, shape, mean=0., std=1.):
        """Draw an array of Gaussian random numbers.

        Parameters
        ----------
        dtype : numpy dtype or type
            Floating-point or complex data type of the result.
        shape : int or tuple of int
            Shape of the result.
        mean : scalar
            Mean of the distribution. It may be complex only if `dtype` is.
        std : scalar
            Standard deviation of the distribution; must not be complex.

        Returns
        -------
        numpy.ndarray

        Raises
        ------
        TypeError
            If `dtype` is neither float nor complex, if `mean` or `std` is
            not a scalar, if `std` is complex, or if `mean` is complex
            while `dtype` is real.
        """
        is_complex = np.issubdtype(dtype, np.complexfloating)
        if not (is_complex or np.issubdtype(dtype, np.floating)):
            raise TypeError("dtype must be float or complex")
        if not (np.isscalar(mean) and np.isscalar(std)):
            raise TypeError("mean and std must be scalars")
        if np.issubdtype(type(std), np.complexfloating):
            raise TypeError("std must not be complex")
        if not is_complex and np.issubdtype(type(mean), np.complexfloating):
            raise TypeError("mean must not be complex for a real result field")

        if not is_complex:
            return np.random.normal(mean, std, shape).astype(dtype, copy=False)

        # Give each of the real and imaginary parts half of the variance,
        # so the total variance of each complex entry is std**2.
        sigma = std*np.sqrt(0.5)
        result = np.empty(shape, dtype=dtype)
        result.real = np.random.normal(mean.real, sigma, shape)
        result.imag = np.random.normal(mean.imag, sigma, shape)
        return result

    @staticmethod
    def uniform(dtype, shape, low=0., high=1.):
        """Draw an array of uniformly distributed random numbers.

        Parameters
        ----------
        dtype : numpy dtype or type
            Data type of the result.
        shape : int or tuple of int
            Shape of the result.
        low, high : scalar
            Bounds of the distribution; must not be complex. For an integer
            `dtype` both must be integers, and `high` is inclusive. Otherwise
            values come from ``np.random.uniform(low, high)``. For a complex
            `dtype`, the real and imaginary parts are drawn independently.

        Returns
        -------
        numpy.ndarray

        Raises
        ------
        TypeError
            If `low` or `high` is not a scalar, is complex, or (for an
            integer `dtype`) is not an integer.
        """
        bounds = (low, high)
        if not all(np.isscalar(b) for b in bounds):
            raise TypeError("low and high must be scalars")
        if any(np.issubdtype(type(b), np.complexfloating) for b in bounds):
            raise TypeError("low and high must not be complex")

        if np.issubdtype(dtype, np.complexfloating):
            result = np.empty(shape, dtype=dtype)
            result.real = np.random.uniform(low, high, shape)
            result.imag = np.random.uniform(low, high, shape)
        elif np.issubdtype(dtype, np.integer):
            if not all(np.issubdtype(type(b), np.integer) for b in bounds):
                raise TypeError("low and high must be integer")
            # randint excludes its upper limit; add one so `high` can occur.
            result = np.random.randint(low, high+1, shape)
        else:
            result = np.random.uniform(low, high, shape)
        return result.astype(dtype, copy=False)