Commit 1c1b6377 authored by Theo Steininger's avatar Theo Steininger

Merge branch 'tweak_limited_exp' into 'master'

Tweak limited exp

See merge request !193
parents 41c87296 9f86a5cb
Pipeline #17226 passed with stages
in 35 minutes and 42 seconds
......@@ -24,7 +24,7 @@ from .field import Field
__all__ = ['cos', 'sin', 'cosh', 'sinh', 'tan', 'tanh', 'arccos', 'arcsin',
'arccosh', 'arcsinh', 'arctan', 'arctanh', 'sqrt', 'exp', 'log',
'conjugate', 'clipped_exp', 'limited_exp']
'conjugate', 'clipped_exp', 'limited_exp', 'limited_exp_deriv']
def _math_helper(x, function):
......@@ -101,15 +101,28 @@ def clipped_exp(x):
def limited_exp(x):
thr = 200
expthr = np.exp(thr)
return _math_helper(x, lambda z: _limited_exp_helper(z, thr, expthr))
return _math_helper(x, _limited_exp_helper)
def _limited_exp_helper(x):
thr = 200.
mask = x>thr
if np.count_nonzero(mask) == 0:
return np.exp(x)
result = ((1.-thr) + x)*np.exp(thr)
result[~mask] = np.exp(x[~mask])
return result
def _limited_exp_helper(x, thr, expthr):
mask = (x > thr)
result = np.exp(x)
result[mask] = ((1-thr) + x[mask])*expthr
def limited_exp_deriv(x):
return _math_helper(x, _limited_exp_deriv_helper)
def _limited_exp_deriv_helper(x):
thr = 200.
mask = x>thr
if np.count_nonzero(mask) == 0:
return np.exp(x)
result = np.empty_like(x)
result[mask] = np.exp(thr)
result[~mask] = np.exp(x[~mask])
return result
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment