Commit 104dc2d7 authored by Martin Reinecke's avatar Martin Reinecke

- tweak limited_exp to avoid triggering overflows

- add limited_exp_deriv
parent 20edb351
Pipeline #16692 passed with stage
in 29 minutes and 7 seconds
......@@ -23,7 +23,7 @@ from nifty.field import Field
__all__ = ['cos', 'sin', 'cosh', 'sinh', 'tan', 'tanh', 'arccos', 'arcsin',
'arccosh', 'arcsinh', 'arctan', 'arctanh', 'sqrt', 'exp', 'log',
'conjugate', 'clipped_exp', 'limited_exp']
'conjugate', 'clipped_exp', 'limited_exp', 'limited_exp_deriv']
def _math_helper(x, function):
......@@ -100,15 +100,28 @@ def clipped_exp(x):
def limited_exp(x):
thr = 200
expthr = np.exp(thr)
return _math_helper(x, lambda z: _limited_exp_helper(z, thr, expthr))
return _math_helper(x, _limited_exp_helper)
def _limited_exp_helper(x):
thr = 200.
mask = x>thr
if np.count_nonzero(mask) == 0:
return np.exp(x)
result = ((1.-thr) + x)*np.exp(thr)
result[~mask] = np.exp(x[~mask])
return result
def _limited_exp_helper(x, thr, expthr):
mask = (x > thr)
result = np.exp(x)
result[mask] = ((1-thr) + x[mask])*expthr
def limited_exp_deriv(x):
return _math_helper(x, _limited_exp_deriv_helper)
def _limited_exp_deriv_helper(x):
thr = 200.
mask = x>thr
if np.count_nonzero(mask) == 0:
return np.exp(x)
result = np.empty_like(x)
result[mask] = np.exp(thr)
result[~mask] = np.exp(x[~mask])
return result
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment