Commit cf99895f authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'clipped_exp' into 'NIFTy_5'

Add clipped_exp

See merge request ift/nifty-dev!128
parents 1f84d24e 988f368b
......@@ -34,7 +34,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"distaxis", "from_local_data", "from_global_data", "to_global_data",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "transpose", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed"]
"ensure_not_distributed", "ensure_default_distributed",
"clipped_exp"]
_comm = MPI.COMM_WORLD
ntask = _comm.Get_size()
......@@ -303,6 +304,10 @@ for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
setattr(_current_module, f, func(f))
def clipped_exp(a):
return data_object(x.shape, np.exp(np.clip(x.data, -300, 300), x.distaxis))
def from_object(object, dtype, copy, set_locked):
if dtype is None:
dtype = object.dtype
......
......@@ -33,7 +33,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"distaxis", "from_local_data", "from_global_data", "to_global_data",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed"]
"ensure_not_distributed", "ensure_default_distributed",
"clipped_exp"]
ntask = 1
rank = 0
......@@ -149,3 +150,7 @@ def absmax(arr):
def norm(arr, ord=2):
return np.linalg.norm(arr.reshape(-1), ord=ord)
def clipped_exp(arr):
return np.exp(np.clip(arr, -300, 300))
......@@ -634,6 +634,9 @@ class Field(object):
def positive_tanh(self):
return 0.5*(1.+self.tanh())
def clipped_exp(self):
return Field(self._domain, dobj.clipped_exp(self._val))
def _binary_op(self, other, op):
# if other is a field, make sure that the domains match
f = getattr(self._val, op)
......
......@@ -183,6 +183,10 @@ class Linearization(object):
tmp = self._val.exp()
return self.new(tmp, makeOp(tmp)(self._jac))
def clipped_exp(self):
tmp = self._val.clipped_exp()
return self.new(tmp, makeOp(tmp)(self._jac))
def log(self):
tmp = self._val.log()
return self.new(tmp, makeOp(1./self._val)(self._jac))
......
......@@ -292,7 +292,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
setattr(MultiField, op, func(op))
for f in ["sqrt", "exp", "log", "tanh"]:
for f in ["sqrt", "exp", "log", "tanh", "clipped_exp"]:
def func(f):
def func2(self):
fu = getattr(Field, f)
......
......@@ -99,7 +99,7 @@ class Operator(NiftyMetaBase()):
return self.__class__.__name__
for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]:
for f in ["sqrt", "exp", "log", "tanh", "positive_tanh", 'clipped_exp']:
def func(f):
def func2(self):
fa = _FunctionApplier(self.target, f)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment