numpy_do.py 2.59 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

19
20
21
22
# Data object module for NIFTy that uses simple numpy ndarrays.

import numpy as np
from numpy import ndarray as data_object
23
from numpy import full, empty, empty_like, sqrt, ones, zeros, vdot, \
Martin Reinecke's avatar
Martin Reinecke committed
24
                  exp, log, tanh
Martin Reinecke's avatar
Martin Reinecke committed
25
26
from .random import Random

Martin Reinecke's avatar
Martin Reinecke committed
27
28
29
30
# Single-process backend: one task in total, this process is rank 0 and is
# the master. (Presumably these mirror the globals of a distributed/MPI
# backend with the same interface -- confirm against the sibling modules.)
ntask, rank, master = 1, 0, True

31

Martin Reinecke's avatar
Martin Reinecke committed
32
33
34
35
def is_numpy():
    """Report whether this data-object backend is the plain-numpy one.

    Returns
    -------
    bool
        Always ``True`` for this module.
    """
    return True


Martin Reinecke's avatar
Martin Reinecke committed
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def from_object(object, dtype, copy, set_locked):
    """Build an ndarray from ``object`` with the requested dtype and locking.

    Parameters
    ----------
    object : array-like with a ``dtype`` attribute
        Source data.
    dtype : dtype or None
        Target dtype; ``None`` keeps ``object.dtype``.
    copy : bool
        Whether a copy may/should be made (passed to ``np.array``).
    set_locked : bool
        If true, the result is made read-only via ``lock``.

    Returns
    -------
    numpy.ndarray
        ``object`` itself if it is already locked and has the target dtype
        (and ``set_locked`` is requested); otherwise ``np.array(object, ...)``.

    Raises
    ------
    ValueError
        If a dtype change or locking is requested without ``copy``.
    """
    target = object.dtype if dtype is None else dtype
    same_dtype = target == object.dtype
    # An already read-only array of the right dtype can be shared as-is,
    # even when a copy was requested.
    if set_locked and same_dtype and locked(object):
        return object
    if not (same_dtype or copy):
        raise ValueError("cannot change data type without copying")
    if set_locked and not copy:
        raise ValueError("cannot lock object without copying")
    result = np.array(object, dtype=target, copy=copy)
    if set_locked:
        lock(result)
    return result
Martin Reinecke's avatar
Martin Reinecke committed
50
51
52
53
54


def from_random(random_type, shape, dtype=np.float64, **kwargs):
    """Draw a random array using the generator ``Random.<random_type>``.

    Parameters
    ----------
    random_type : str
        Name of a generator attribute on ``Random``.
    shape : tuple of int
        Shape forwarded to the generator.
    dtype : dtype, optional
        Data type forwarded to the generator (default ``np.float64``).
    **kwargs
        Extra keyword arguments forwarded unchanged to the generator.

    Returns
    -------
    Whatever the selected ``Random`` generator returns.
    """
    return getattr(Random, random_type)(dtype=dtype, shape=shape, **kwargs)
Martin Reinecke's avatar
Martin Reinecke committed
55

Martin Reinecke's avatar
Martin Reinecke committed
56

Martin Reinecke's avatar
Martin Reinecke committed
57
58
59
60
def local_data(arr):
    """Return the locally held part of ``arr``.

    This backend keeps everything in one process, so the very same object
    is handed back (no copy is made).
    """
    local = arr
    return local


61
62
63
64
def ibegin_from_shape(glob_shape, distaxis=-1):
    """Return the start index of the local portion for a given global shape.

    Every axis starts at 0 here, so the result is a tuple of zeros with one
    entry per dimension of ``glob_shape``. ``distaxis`` is ignored.
    """
    ndim = len(glob_shape)
    return tuple([0] * ndim)


Martin Reinecke's avatar
Martin Reinecke committed
65
66
67
68
69
70
71
72
def ibegin(arr):
    """Return the start index of the locally held portion of ``arr``.

    Always zero along every axis in this backend (one entry per dimension).
    """
    return tuple(0 for _ in arr.shape)


def np_allreduce_sum(arr):
    """Sum-combine ``arr`` across all tasks.

    With a single task there is nothing to combine, so ``arr`` itself is
    returned. (Presumably the counterpart of an MPI-style sum all-reduce in
    a distributed backend.)
    """
    reduced = arr
    return reduced


73
74
75
76
def np_allreduce_min(arr):
    """Min-combine ``arr`` across all tasks.

    With a single task there is nothing to combine, so ``arr`` itself is
    returned. (Presumably the counterpart of an MPI-style min all-reduce in
    a distributed backend.)
    """
    reduced = arr
    return reduced


Martin Reinecke's avatar
fixes  
Martin Reinecke committed
77
def distaxis(arr):
    """Return the distributed axis of ``arr``.

    Always ``-1`` in this backend, regardless of ``arr`` (presumably meaning
    "no axis is distributed").
    """
    no_axis = -1
    return no_axis
Martin Reinecke's avatar
Martin Reinecke committed
79
80


Martin Reinecke's avatar
Martin Reinecke committed
81
def from_local_data(shape, arr, distaxis=-1):
    """Turn process-local data into a data object of the given global shape.

    In this single-process backend the local data is the whole array, so
    ``arr`` is returned unchanged once its shape has been checked.

    Parameters
    ----------
    shape : sequence of int
        Expected global shape; compared as ``tuple(shape)``.
    arr : numpy.ndarray
        The local data.
    distaxis : int, optional
        Ignored by this backend.

    Returns
    -------
    numpy.ndarray
        ``arr`` itself (no copy).

    Raises
    ------
    ValueError
        If ``tuple(shape)`` differs from ``arr.shape``.
    """
    expected = tuple(shape)
    if expected != arr.shape:
        # Report both shapes so a mismatch is diagnosable from the message.
        raise ValueError("shape mismatch: expected {}, got {}"
                         .format(expected, arr.shape))
    return arr


87
def from_global_data(arr, sum_up=False, distaxis=-1):
    """Build a data object from a full (global) array.

    With a single process the global array is the data object, so ``arr``
    is returned unchanged; ``sum_up`` and ``distaxis`` are ignored here.
    """
    result = arr
    return result


Martin Reinecke's avatar
Martin Reinecke committed
91
def to_global_data(arr):
    """Return the full (global) array behind the data object ``arr``.

    In this single-process backend that is ``arr`` itself (no copy).
    """
    result = arr
    return result


Martin Reinecke's avatar
Martin Reinecke committed
95
def redistribute(arr, dist=None, nodist=None):
    """Change how ``arr`` is distributed across tasks.

    Nothing is distributed in this backend, so ``arr`` is returned unchanged;
    ``dist`` and ``nodist`` are ignored.
    """
    result = arr
    return result


Martin Reinecke's avatar
fixes  
Martin Reinecke committed
99
def default_distaxis():
    """Return the default distributed axis for this backend.

    Always ``-1``, the same value used as the ``distaxis`` default by the
    other functions in this module.
    """
    no_axis = -1
    return no_axis


103
def local_shape(glob_shape, distaxis=-1):
    """Return the shape of the locally held portion for a global shape.

    Everything is local here, so ``glob_shape`` is returned as given (same
    object); ``distaxis`` is ignored.
    """
    shape = glob_shape
    return shape
105
106
107
108
109
110
111
112


def lock(arr):
    """Make ``arr`` read-only in place by clearing its WRITEABLE flag.

    Returns ``None``.
    """
    arr.setflags(write=False)


def locked(arr):
    """Return ``True`` if ``arr`` is read-only (WRITEABLE flag cleared)."""
    writeable = arr.flags["WRITEABLE"]
    return not writeable