numpy_do.py 3.97 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

19 20 21
# Data object module for NIFTy that uses simple numpy ndarrays.

import numpy as np
Philipp Arras's avatar
Philipp Arras committed
22
from numpy import empty, empty_like, exp, full, log
23
from numpy import ndarray as data_object
Philipp Arras's avatar
Philipp Arras committed
24
from numpy import ones, sqrt, tanh, vdot, zeros
25 26
from numpy import sin, cos, tan, sinh, cosh, sinc
from numpy import absolute, sign
Martin Reinecke's avatar
Martin Reinecke committed
27 28
from .random import Random

Martin Reinecke's avatar
Martin Reinecke committed
29 30 31 32 33 34
# Public interface of this data-object backend; the names mirror the
# distributed backend so that callers can switch between them transparently.
__all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
           "empty", "zeros", "ones", "empty_like", "vdot", "exp",
           "log", "tanh", "sqrt", "from_object", "from_random",
           "local_data", "ibegin", "ibegin_from_shape", "np_allreduce_sum",
           "np_allreduce_min", "np_allreduce_max",
           "distaxis", "from_local_data", "from_global_data", "to_global_data",
           "redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
           "lock", "locked", "uniform_full", "to_global_data_rw",
           "ensure_not_distributed", "ensure_default_distributed",
           "clipped_exp", "hardplus", "sin", "cos", "tan", "sinh",
           "cosh", "absolute", "sign", "sinc"]
Martin Reinecke's avatar
Martin Reinecke committed
40

Martin Reinecke's avatar
Martin Reinecke committed
41 42 43 44
# Single-process backend: exactly one task exists, it has rank 0 and it is
# also the master task.
ntask, rank, master = 1, 0, True

45

Martin Reinecke's avatar
Martin Reinecke committed
46 47 48 49
def is_numpy():
    """Report whether this backend uses plain numpy arrays (it always does)."""
    uses_plain_numpy = True
    return uses_plain_numpy


Martin Reinecke's avatar
Martin Reinecke committed
50 51 52 53 54 55 56 57 58 59 60 61 62 63
def from_object(object, dtype, copy, set_locked):
    """Convert `object` to a numpy array, optionally copying and locking it.

    Parameters
    ----------
    object : array-like with a ``dtype`` attribute
        Source data.
    dtype : dtype or None
        Target dtype; ``None`` keeps ``object.dtype``.
    copy : bool
        Whether the result may/should be a fresh copy.
    set_locked : bool
        If True, the result is made read-only via ``lock``.

    Returns
    -------
    numpy.ndarray
        ``object`` itself if it is already locked with the requested dtype
        and locking was requested; otherwise ``np.array(object, ...)``.

    Raises
    ------
    ValueError
        If a dtype change or locking is requested while ``copy`` is False.
    """
    source_dtype = object.dtype
    target_dtype = source_dtype if dtype is None else dtype
    same_dtype = target_dtype == source_dtype

    # Already read-only with the right dtype: nothing needs to be done.
    if set_locked and same_dtype and locked(object):
        return object

    if not copy:
        if not same_dtype:
            raise ValueError("cannot change data type without copying")
        if set_locked:
            # Locking in place would freeze the caller's array as well.
            raise ValueError("cannot lock object without copying")

    converted = np.array(object, dtype=target_dtype, copy=copy)
    if set_locked:
        lock(converted)
    return converted
Martin Reinecke's avatar
Martin Reinecke committed
64 65 66 67 68


def from_random(random_type, shape, dtype=np.float64, **kwargs):
    """Draw an array of random numbers using the ``Random`` helper.

    Parameters
    ----------
    random_type : str
        Name of the ``Random`` attribute to call (looked up via ``getattr``).
    shape : tuple of int
        Shape of the requested array.
    dtype : dtype
        Requested dtype, forwarded to the generator.
    **kwargs
        Extra keyword arguments forwarded unchanged to the generator.
    """
    return getattr(Random, random_type)(dtype=dtype, shape=shape, **kwargs)
Martin Reinecke's avatar
Martin Reinecke committed
69

Martin Reinecke's avatar
Martin Reinecke committed
70

Martin Reinecke's avatar
Martin Reinecke committed
71 72 73 74
def local_data(arr):
    """Return the locally held part of `arr`.

    In this non-distributed backend the local part is the whole array, so
    `arr` itself (not a copy) is handed back.
    """
    local_part = arr
    return local_part


75 76 77 78
def ibegin_from_shape(glob_shape, distaxis=-1):
    """Return the starting index of the local slab for a global shape.

    Nothing is distributed here, so the local data always starts at the
    origin: a tuple of zeros with one entry per dimension. `distaxis` is
    accepted for interface compatibility and ignored.
    """
    return tuple([0] * len(glob_shape))


Martin Reinecke's avatar
Martin Reinecke committed
79 80 81 82 83 84 85 86
def ibegin(arr):
    """Return the starting index of the local part of `arr`.

    Always the origin (one zero per dimension of ``arr.ndim``) because
    this backend never splits arrays.
    """
    return tuple(0 for _ in range(arr.ndim))


def np_allreduce_sum(arr):
    """Sum-reduce `arr` across all tasks.

    With a single task there is nothing to combine, so the input object is
    returned unchanged.
    """
    reduced = arr
    return reduced


87 88 89 90
def np_allreduce_min(arr):
    """Min-reduce `arr` across all tasks.

    With a single task there is nothing to combine, so the input object is
    returned unchanged.
    """
    reduced = arr
    return reduced


Martin Reinecke's avatar
fixes  
Martin Reinecke committed
91 92 93 94
def np_allreduce_max(arr):
    """Max-reduce `arr` across all tasks.

    With a single task there is nothing to combine, so the input object is
    returned unchanged.
    """
    reduced = arr
    return reduced


Martin Reinecke's avatar
fixes  
Martin Reinecke committed
95
def distaxis(arr):
    """Return the axis along which `arr` is distributed.

    This backend never distributes data; ``-1`` is returned for every
    input, and `arr` itself is not inspected.
    """
    return -1
Martin Reinecke's avatar
Martin Reinecke committed
97 98


Martin Reinecke's avatar
Martin Reinecke committed
99
def from_local_data(shape, arr, distaxis=-1):
    """Build a data object of global `shape` from the local array `arr`.

    Since the local part is the whole array here, `arr` is returned as-is
    after checking that its shape matches. `distaxis` is accepted for
    interface compatibility and ignored.

    Parameters
    ----------
    shape : sequence of int
        Expected global shape.
    arr : numpy.ndarray
        Local data.

    Returns
    -------
    numpy.ndarray
        `arr` itself (no copy).

    Raises
    ------
    ValueError
        If ``tuple(shape)`` differs from ``arr.shape``.
    """
    expected = tuple(shape)
    if expected != arr.shape:
        raise ValueError("shape mismatch: expected {}, got {}"
                         .format(expected, arr.shape))
    return arr


105
def from_global_data(arr, sum_up=False, distaxis=-1):
    """Create a data object from a full (global) array.

    The global array is the data object in this backend, so `arr` is
    returned unchanged (no copy). `sum_up` and `distaxis` are accepted for
    interface compatibility and ignored.
    """
    return arr


Martin Reinecke's avatar
Martin Reinecke committed
109
def to_global_data(arr):
    """Return the full (global) array behind the data object `arr`.

    `arr` itself is returned, not a copy; use ``to_global_data_rw`` when a
    private writeable copy is needed.
    """
    return arr


113 114 115 116
def to_global_data_rw(arr):
    """Return a private copy of the global array that the caller may modify.

    The copy is made with ``arr.copy()``, so it is writeable even when
    `arr` has been locked and changes never propagate back to `arr`.
    """
    writable = arr.copy()
    return writable


Martin Reinecke's avatar
Martin Reinecke committed
117
def redistribute(arr, dist=None, nodist=None):
    """Change the distribution layout of `arr`.

    There is only one layout (undistributed) in this backend, so `arr` is
    returned unchanged; `dist` and `nodist` are accepted for interface
    compatibility and ignored.
    """
    return arr


Martin Reinecke's avatar
fixes  
Martin Reinecke committed
121
def default_distaxis():
    """Return the default distribution axis; ``-1`` (no distribution)."""
    return -1


125
def local_shape(glob_shape, distaxis=-1):
    """Return the shape of the local part of an array of shape `glob_shape`.

    Nothing is distributed, so the local shape is the global shape; the
    very object passed in is returned. `distaxis` is ignored.
    """
    return glob_shape
127 128 129 130 131 132 133 134


def lock(arr):
    """Make `arr` read-only in place by clearing its WRITEABLE flag.

    Returns None; the array object itself is modified.
    """
    arr.setflags(write=False)


def locked(arr):
    """Return True if `arr` is read-only (its WRITEABLE flag is cleared)."""
    return not arr.flags["WRITEABLE"]
Martin Reinecke's avatar
Martin Reinecke committed
135 136 137 138


def uniform_full(shape, fill_value, dtype=None, distaxis=-1):
    """Return a read-only array of `shape` whose entries all equal `fill_value`.

    The result is a broadcast view of a single value, so memory for the
    full shape is not allocated and the array is not writeable.

    Parameters
    ----------
    shape : tuple of int
        Shape of the result.
    fill_value : scalar
        Value of every entry.
    dtype : dtype or None
        Dtype of the result; ``None`` infers it from `fill_value`.
    distaxis : int
        Ignored; accepted for interface compatibility.
    """
    # Convert first so that an explicitly requested dtype is honoured;
    # broadcast_to alone would always use the dtype inferred from fill_value.
    value = np.asarray(fill_value, dtype=dtype)
    return np.broadcast_to(value, shape)
Martin Reinecke's avatar
Martin Reinecke committed
139 140 141 142 143 144 145 146


def ensure_not_distributed(arr, axes):
    """Make sure `axes` of `arr` are not distributed.

    Nothing is ever distributed in this backend, so a pair containing
    `arr` twice is returned (presumably the redistributed object and its
    local data, which coincide here). `axes` is ignored.
    """
    return (arr,) * 2


def ensure_default_distributed(arr):
    """Bring `arr` into the default distribution layout.

    The only layout here is the undistributed one, so `arr` is returned
    unchanged.
    """
    result = arr
    return result
Martin Reinecke's avatar
Martin Reinecke committed
147 148 149


def absmax(arr):
    """Return the largest absolute value among the entries of `arr`.

    Parameters
    ----------
    arr : numpy.ndarray
        Input array of any shape (real or complex).

    Returns
    -------
    numpy floating scalar
        ``max(|arr|)``.

    Raises
    ------
    ValueError
        If `arr` is empty (the underlying max reduction has no identity).
    """
    # Flatten first: for 2-D input, ord=inf would otherwise be the matrix
    # infinity norm (max row sum) instead of the max absolute entry.
    return np.linalg.norm(arr.reshape(-1), ord=np.inf)
Martin Reinecke's avatar
Martin Reinecke committed
151 152 153


def norm(arr, ord=2):
    """Return the vector `ord`-norm of all entries of `arr`.

    Parameters
    ----------
    arr : numpy.ndarray
        Input array of any shape.
    ord : int, float or inf
        Norm order as accepted by ``numpy.linalg.norm`` for vectors.
    """
    # Flatten so that a vector norm is computed regardless of arr.ndim
    # (numpy.linalg.norm would compute a matrix norm for 2-D input).
    return np.linalg.norm(arr.reshape(-1), ord=ord)
Martin Reinecke's avatar
tweaks  
Martin Reinecke committed
155 156 157 158


def clipped_exp(arr):
    """Return ``exp(arr)`` with the argument clipped to ``[-300, 300]``.

    Bounding the exponent keeps the result finite for very large inputs
    and away from exact zero for very negative ones.
    """
    bounded = np.clip(arr, -300, 300)
    return np.exp(bounded)
159 160 161


def hardplus(arr):
    """Return `arr` with every entry below ``1e-20`` raised to ``1e-20``.

    Only a lower bound is applied (no upper bound), so the result is
    strictly positive for all non-NaN entries. A new array is returned;
    `arr` is left untouched.
    """
    return np.clip(arr, 1e-20, None)