Commit d043574e authored by Martin Reinecke's avatar Martin Reinecke

parameterize hardplus()

parent d08d0cbd
......@@ -36,7 +36,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"lock", "locked", "uniform_full", "transpose", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed",
"clipped_exp", "tanh", "conjugate", "sin", "cos", "tan",
"sinh", "cosh", "sinc", "absolute", "sign"]
"sinh", "cosh", "sinc", "absolute", "sign", "hardplus"]
_comm = MPI.COMM_WORLD
ntask = _comm.Get_size()
......@@ -310,8 +310,8 @@ def clipped_exp(x):
return data_object(x.shape, np.exp(np.clip(x.data, -300, 300), x.distaxis))
def hardplus(x):
return data_object(x.shape, np.clip(x.data, 1e-20, None), x.distaxis)
def hardplus(x, eps):
return data_object(x.shape, np.clip(x.data, eps, None), x.distaxis)
def from_object(object, dtype, copy, set_locked):
......
......@@ -158,5 +158,5 @@ def clipped_exp(arr):
return np.exp(np.clip(arr, -300, 300))
def hardplus(arr):
return np.clip(arr, 1e-20, None)
def hardplus(arr, eps):
return np.clip(arr, eps, None)
......@@ -637,8 +637,8 @@ class Field(object):
def clipped_exp(self):
return Field(self._domain, dobj.clipped_exp(self._val))
def hardplus(self):
return Field(self._domain, dobj.hardplus(self._val))
def hardplus(self, eps):
return Field(self._domain, dobj.hardplus(self._val, eps))
def one_over(self):
return 1/self
......
......@@ -187,9 +187,9 @@ class Linearization(object):
tmp = self._val.clipped_exp()
return self.new(tmp, makeOp(tmp)(self._jac))
def hardplus(self):
tmp = self._val.hardplus()
tmp2 = makeOp(1.-(tmp == 1e-20))
def hardplus(self, eps):
tmp = self._val.hardplus(eps)
tmp2 = makeOp(1.-(tmp == eps))
return self.new(tmp, tmp2(self._jac))
def sin(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment