Commit 7ac5e74c authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'expm1' into 'NIFTy_5'

Introduce expm1 by wrapping the numpy function

See merge request !374
parents f240b137 f4b90f1e
Pipeline #63897 passed with stages
in 22 minutes and 53 seconds
......@@ -32,7 +32,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "transpose", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed",
"tanh", "conjugate", "sin", "cos", "tan", "log10", "log1p",
"tanh", "conjugate", "sin", "cos", "tan", "log10", "log1p", "expm1",
"sinh", "cosh", "sinc", "absolute", "sign", "clip"]
_comm = MPI.COMM_WORLD
......@@ -297,7 +297,8 @@ def _math_helper(x, function, out):
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate", "sin", "cos", "tan",
"sinh", "cosh", "sinc", "absolute", "sign", "log10", "log1p"]:
"sinh", "cosh", "sinc", "absolute", "sign", "log10", "log1p",
"expm1"]:
def func(f):
def func2(x, out=None):
return _math_helper(x, f, out)
......
......@@ -22,7 +22,7 @@ from numpy import ndarray as data_object
from numpy import empty, empty_like, ones, zeros, full
from numpy import absolute, sign, clip, vdot
from numpy import sin, cos, sinh, cosh, tan, tanh
from numpy import exp, log, log10, sqrt, sinc, log1p
from numpy import exp, log, log10, sqrt, sinc, log1p, expm1
from .random import Random
......@@ -35,8 +35,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed",
"clip", "sin", "cos", "tan", "sinh",
"cosh", "absolute", "sign", "sinc", "log10", "log1p"]
"clip", "sin", "cos", "tan", "sinh", "cosh",
"absolute", "sign", "sinc", "log10", "log1p", "expm1"]
ntask = 1
rank = 0
......
......@@ -663,9 +663,8 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
return func2
setattr(Field, op, func(op))
for f in ["sqrt", "exp", "log", "log10", "log1p", "tanh",
"sin", "cos", "tan", "cosh", "sinh",
"absolute", "sinc", "sign"]:
for f in ["sqrt", "exp", "log", "sin", "cos", "tan", "sinh", "cosh", "tanh",
"absolute", "sinc", "sign", "log10", "log1p", "expm1"]:
def func(f):
def func2(self):
return Field(self._domain, getattr(dobj, f)(self.val))
......
......@@ -336,10 +336,14 @@ class Linearization(object):
return self.new(tmp, makeOp(tmp2)(self._jac))
def log1p(self):
xval = self.val
res = xval.log1p()
jac = makeOp(1. / (1. + xval))
return self.new(res, jac @ self.jac)
tmp = self._val.log1p()
tmp2 = 1. / (1. + self._val)
return self.new(tmp, makeOp(tmp2)(self.jac))
def expm1(self):
tmp = self._val.expm1()
tmp2 = self._val.exp()
return self.new(tmp, makeOp(tmp2)(self.jac))
def sinh(self):
tmp = self._val.sinh()
......
......@@ -338,7 +338,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
setattr(MultiField, op, func(op))
for f in ["sqrt", "exp", "log", "log1p", "tanh"]:
for f in ["sqrt", "exp", "log", "log1p", "expm1", "tanh"]:
def func(f):
def func2(self):
fu = getattr(Field, f)
......
......@@ -36,7 +36,7 @@ __all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'full', 'from_global_data', 'from_local_data',
'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'sigmoid',
'sin', 'cos', 'tan', 'sinh', 'cosh', 'log10',
'absolute', 'one_over', 'clip', 'sinc',
'absolute', 'one_over', 'clip', 'sinc', "log1p", "expm1",
'conjugate', 'get_signal_variance', 'makeOp', 'domain_union',
'get_default_codomain', 'single_plot']
......@@ -391,7 +391,7 @@ _current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "log10", "tanh", "sigmoid",
"conjugate", 'sin', 'cos', 'tan', 'sinh', 'cosh',
'absolute', 'one_over', 'sinc']:
'absolute', 'one_over', 'sinc', 'log1p', 'expm1']:
def func(f):
def func2(x):
from .linearization import Linearization
......
......@@ -54,7 +54,7 @@ def test_special_gradients():
@pmp('f', [
'log', 'exp', 'sqrt', 'sin', 'cos', 'tan', 'sinc', 'sinh', 'cosh', 'tanh',
'absolute', 'one_over', 'sigmoid', 'log10', 'log1p'
'absolute', 'one_over', 'sigmoid', 'log10', 'log1p', "expm1"
])
def test_actual_gradients(f):
dom = ift.UnstructuredDomain((1,))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment