# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from __future__ import division
import numpy as np
from d2o import distributed_data_object
from .field import Field


# Public names exported by ``from <module> import *``.
# NOTE(review): ``conj`` is defined in this module but not listed here --
# confirm whether leaving it out of the star-import API is intended.
__all__ = ['cos', 'sin', 'cosh', 'sinh', 'tan', 'tanh', 'arccos', 'arcsin',
           'arccosh', 'arcsinh', 'arctan', 'arctanh', 'sqrt', 'exp', 'log',
           'conjugate', 'clipped_exp', 'limited_exp', 'limited_exp_deriv']


def _math_helper(x, function):
    """Apply the element-wise callable `function` to `x`.

    Dispatches on the container type of `x`:

    * ``Field`` -- `function` is applied to the field's values through
      ``x.val.apply_scalar_function``; the values are then stored in an
      empty copy of `x` whose dtype is taken from the computed values.
    * ``distributed_data_object`` -- `function` is applied out of place
      with ``x.apply_scalar_function(function, inplace=False)``.
    * anything else -- `x` is converted with ``np.asarray`` and `function`
      is called on the resulting array.

    Parameters
    ----------
    x : Field, distributed_data_object or array_like
        Input values.
    function : callable
        Function applied to the values, e.g. a NumPy ufunc such as
        ``np.exp``.

    Returns
    -------
    The transformed values, in the container produced by the branch taken
    above.
    """
    if isinstance(x, Field):
        result_val = x.val.apply_scalar_function(function)
        # Use the dtype the function actually produced (np.sqrt on integers
        # yields floats, for instance). copy_empty presumably keeps x's
        # structure without its data -- assumption, not visible here.
        result = x.copy_empty(dtype=result_val.dtype)
        result.val = result_val
    elif isinstance(x, distributed_data_object):
        result = x.apply_scalar_function(function, inplace=False)
    else:
        result = function(np.asarray(x))

    return result


def cos(x):
    """Cosine of `x` (``np.cos``).

    Dispatched through `_math_helper`, so `x` may be a Field, a
    distributed_data_object or any array_like.
    """
    return _math_helper(x, function=np.cos)


def sin(x):
    """Sine of `x` (``np.sin``).

    Dispatched through `_math_helper`, so `x` may be a Field, a
    distributed_data_object or any array_like.
    """
    return _math_helper(x, function=np.sin)


def cosh(x):
    """Hyperbolic cosine of `x` (``np.cosh``).

    Dispatched through `_math_helper`, so `x` may be a Field, a
    distributed_data_object or any array_like.
    """
    return _math_helper(x, function=np.cosh)


def sinh(x):
    """Hyperbolic sine of `x` (``np.sinh``).

    Dispatched through `_math_helper`, so `x` may be a Field, a
    distributed_data_object or any array_like.
    """
    return _math_helper(x, function=np.sinh)


def tan(x):
    """Tangent of `x` (``np.tan``).

    Dispatched through `_math_helper`, so `x` may be a Field, a
    distributed_data_object or any array_like.
    """
    return _math_helper(x, function=np.tan)


def tanh(x):
    """Hyperbolic tangent of `x` (``np.tanh``).

    Dispatched through `_math_helper`, so `x` may be a Field, a
    distributed_data_object or any array_like.
    """
    return _math_helper(x, function=np.tanh)


def arccos(x):
    """Inverse cosine of `x` (``np.arccos``).

    Dispatched through `_math_helper`, so `x` may be a Field, a
    distributed_data_object or any array_like.
    """
    return _math_helper(x, function=np.arccos)


def arcsin(x):
    """Inverse sine of `x` (``np.arcsin``).

    Dispatched through `_math_helper`, so `x` may be a Field, a
    distributed_data_object or any array_like.
    """
    return _math_helper(x, function=np.arcsin)


def arccosh(x):
    """Inverse hyperbolic cosine of `x` (``np.arccosh``).

    Dispatched through `_math_helper`, so `x` may be a Field, a
    distributed_data_object or any array_like.
    """
    return _math_helper(x, function=np.arccosh)


def arcsinh(x):
    """Inverse hyperbolic sine of `x` (``np.arcsinh``).

    Dispatched through `_math_helper`, so `x` may be a Field, a
    distributed_data_object or any array_like.
    """
    return _math_helper(x, function=np.arcsinh)


def arctan(x):
    """Inverse tangent of `x` (``np.arctan``).

    Dispatched through `_math_helper`, so `x` may be a Field, a
    distributed_data_object or any array_like.
    """
    return _math_helper(x, function=np.arctan)


def arctanh(x):
    """Inverse hyperbolic tangent of `x` (``np.arctanh``).

    Dispatched through `_math_helper`, so `x` may be a Field, a
    distributed_data_object or any array_like.
    """
    return _math_helper(x, function=np.arctanh)


def sqrt(x):
    """Square root of `x` (``np.sqrt``).

    Dispatched through `_math_helper`, so `x` may be a Field, a
    distributed_data_object or any array_like.
    """
    return _math_helper(x, function=np.sqrt)


def exp(x):
    """Exponential of `x` (``np.exp``).

    Dispatched through `_math_helper`, so `x` may be a Field, a
    distributed_data_object or any array_like. No overflow protection is
    applied; see `clipped_exp` and `limited_exp` for bounded variants.
    """
    return _math_helper(x, function=np.exp)


def clipped_exp(x):
    """Exponential of `x` with the exponent capped at 200.

    Every entry larger than 200 is replaced by 200 before ``np.exp`` is
    applied, so for float64 data the result saturates at ``exp(200)``
    instead of overflowing. Dispatched through `_math_helper`, so `x` may
    be a Field, a distributed_data_object or any array_like.
    """
    def _capped_exp(values):
        return np.exp(np.minimum(200, values))

    return _math_helper(x, _capped_exp)


def limited_exp(x):
    """Exponential of `x` that grows only linearly above 200.

    Entries up to 200 give ``exp(x)``. Larger entries give the first-order
    Taylor continuation ``exp(200) * (1 + (x - 200))``, which keeps both the
    value and the first derivative continuous at the threshold. Dispatched
    through `_math_helper`, so `x` may be a Field, a distributed_data_object
    or any array_like. `limited_exp_deriv` is the matching derivative.
    """
    return _math_helper(x, _limited_exp_helper)

def _limited_exp_helper(x):
    """Element-wise kernel of `limited_exp`.

    Returns ``exp(x)`` where ``x <= thr`` and the linear continuation
    ``exp(thr) * (1 + (x - thr))`` where ``x > thr``, with ``thr = 200``.

    Both `x` and the intermediate result are passed through ``np.asarray``
    so that 0-d input works. Arithmetic on a 0-d array returns a NumPy
    scalar, which does not support the masked item assignment below.
    """
    thr = 200.
    x = np.asarray(x)
    mask = x > thr
    if not np.any(mask):
        # Fast path: nothing exceeds the threshold.
        return np.exp(x)
    # Evaluate the linear branch everywhere, then overwrite the entries at
    # or below the threshold with exp(x). np.exp only ever sees entries that
    # are not above thr, so for float64 data it cannot overflow here.
    result = np.asarray(((1. - thr) + x)*np.exp(thr))
    result[~mask] = np.exp(x[~mask])
    return result

def limited_exp_deriv(x):
    """Derivative of `limited_exp`.

    Gives ``exp(x)`` for entries up to 200 and the constant ``exp(200)``
    above. Dispatched through `_math_helper`, so `x` may be a Field, a
    distributed_data_object or any array_like.
    """
    return _math_helper(x, function=_limited_exp_deriv_helper)

def _limited_exp_deriv_helper(x):
    """Element-wise kernel of `limited_exp_deriv`.

    Returns ``exp(x)`` where ``x <= thr`` and the constant ``exp(thr)``
    where ``x > thr``, with ``thr = 200``. The constant is the slope of the
    linear continuation used by `_limited_exp_helper`.
    """
    thr = 200.
    # exp(min(x, thr)) is exp(x) below the threshold and exp(thr) above it.
    # np.minimum with the float threshold promotes integer input to float.
    # The previous np.empty_like(x) buffer kept an integer dtype, which
    # truncated the exp values and cast exp(thr) into an int.
    return np.exp(np.minimum(x, thr))


def log(x, base=None):
    """Logarithm of `x`, natural by default.

    Parameters
    ----------
    x : Field, distributed_data_object or array_like
        Input values; ``np.log`` is applied through `_math_helper`.
    base : array_like or None, optional
        If given, the natural logarithm is divided by ``log(base)``,
        yielding the logarithm to that base. Default: None (natural log).
    """
    natural_log = _math_helper(x, np.log)
    if base is None:
        return natural_log
    return natural_log/log(base)


def conjugate(x):
    """Complex conjugate of `x` (``np.conjugate``).

    Dispatched through `_math_helper`, so `x` may be a Field, a
    distributed_data_object or any array_like.
    """
    return _math_helper(x, function=np.conjugate)


def conj(x):
    """Short alias for `conjugate`: complex conjugate of `x`."""
    return conjugate(x)