# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

class Random(object):
    """Static helpers that draw random arrays of a requested dtype and shape.

    All draws use the module-level ``np.random`` functions (the global
    legacy random state), so results can be reproduced with
    ``np.random.seed``.
    """

    @staticmethod
    def pm1(dtype, shape):
        """Draw an array of random signs.

        Parameters
        ----------
        dtype : numpy dtype
            Data type of the result.
        shape : int or tuple of int
            Shape of the result.

        Returns
        -------
        numpy.ndarray
            For complex `dtype`, each entry is drawn uniformly from
            {1, 1j, -1, -1j}; otherwise each entry is drawn uniformly
            from {1, -1}.
        """
        if np.issubdtype(dtype, np.complexfloating):
            # The four points +-1, +-1j, selected by a random index in [0, 4).
            x = np.array([1+0j, 0+1j, -1+0j, 0-1j], dtype=dtype)
            x = x[np.random.randint(4, size=shape)]
        else:
            # Map random bits {0, 1} onto {-1, +1}.
            # NOTE(review): for unsigned integer dtypes the -1 entries wrap
            # around in the final cast.
            x = 2*np.random.randint(2, size=shape) - 1
        return x.astype(dtype, copy=False)

    @staticmethod
    def normal(dtype, shape, mean=0., std=1.):
        """Draw an array of Gaussian random numbers.

        Parameters
        ----------
        dtype : numpy dtype
            Floating-point or complex data type of the result.
        shape : int or tuple of int
            Shape of the result.
        mean : scalar
            Mean of the distribution. May be complex only if `dtype` is
            complex.
        std : scalar
            Standard deviation; must not be complex. For complex `dtype`
            the real and imaginary parts are drawn independently, each with
            standard deviation ``std/sqrt(2)``, so that the expected value
            of ``|x - mean|**2`` equals ``std**2``.

        Returns
        -------
        numpy.ndarray
            Random values of the requested shape and dtype.

        Raises
        ------
        TypeError
            If `dtype` is neither floating-point nor complex, if `mean` or
            `std` is not a scalar, if `std` is complex, or if `mean` is
            complex while `dtype` is real.
        """
        if not (np.issubdtype(dtype, np.floating) or
                np.issubdtype(dtype, np.complexfloating)):
            raise TypeError("dtype must be float or complex")
        if not np.isscalar(mean) or not np.isscalar(std):
            raise TypeError("mean and std must be scalars")
        if np.issubdtype(type(std), np.complexfloating):
            raise TypeError("std must not be complex")
        if ((not np.issubdtype(dtype, np.complexfloating)) and
                np.issubdtype(type(mean), np.complexfloating)):
            raise TypeError("mean must not be complex for a real result field")
        if np.issubdtype(dtype, np.complexfloating):
            x = np.empty(shape, dtype=dtype)
            # Split the variance evenly between the real and imaginary parts.
            x.real = np.random.normal(mean.real, std*np.sqrt(0.5), shape)
            x.imag = np.random.normal(mean.imag, std*np.sqrt(0.5), shape)
        else:
            x = np.random.normal(mean, std, shape).astype(dtype, copy=False)
        return x

    @staticmethod
    def uniform(dtype, shape, low=0., high=1.):
        """Draw an array of uniformly distributed random numbers.

        Parameters
        ----------
        dtype : numpy dtype
            Data type of the result.
        shape : int or tuple of int
            Shape of the result.
        low, high : scalar
            Real bounds of the distribution.
            - For integer `dtype`, both bounds must be integers and values
              are drawn from the closed interval [low, high].
            - Otherwise values are drawn from the half-open interval
              [low, high).
            - For complex `dtype`, the real and imaginary parts are drawn
              independently from that interval.

        Returns
        -------
        numpy.ndarray
            Random values of the requested shape and dtype.

        Raises
        ------
        TypeError
            If `low` or `high` is not a scalar, if either is complex, or if
            `dtype` is an integer type and either bound is not an integer.
        """
        if not np.isscalar(low) or not np.isscalar(high):
            raise TypeError("low and high must be scalars")
        if (np.issubdtype(type(low), np.complexfloating) or
                np.issubdtype(type(high), np.complexfloating)):
            raise TypeError("low and high must not be complex")
        if np.issubdtype(dtype, np.complexfloating):
            x = np.empty(shape, dtype=dtype)
            x.real = np.random.uniform(low, high, shape)
            x.imag = np.random.uniform(low, high, shape)
        elif np.issubdtype(dtype, np.integer):
            if not (np.issubdtype(type(low), np.integer) and
                    np.issubdtype(type(high), np.integer)):
                raise TypeError("low and high must be integer")
            # randint excludes its upper bound; +1 makes `high` inclusive.
            x = np.random.randint(low, high+1, shape)
        else:
            x = np.random.uniform(low, high, shape)
        return x.astype(dtype, copy=False)