diff --git a/nifty5/data_objects/distributed_do.py b/nifty5/data_objects/distributed_do.py index 23dc395f40d69c13970e256b8b51a3905d80add9..629e0bace55408e6a7e76c27d95342e51390fc60 100644 --- a/nifty5/data_objects/distributed_do.py +++ b/nifty5/data_objects/distributed_do.py @@ -36,7 +36,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", "lock", "locked", "uniform_full", "transpose", "to_global_data_rw", "ensure_not_distributed", "ensure_default_distributed", "clipped_exp", "tanh", "conjugate", "sin", "cos", "tan", - "sinh", "cosh", "sinc", "absolute", "sign"] + "sinh", "cosh", "sinc", "absolute", "sign", "hardplus"] _comm = MPI.COMM_WORLD ntask = _comm.Get_size() @@ -310,8 +310,8 @@ def clipped_exp(x): return data_object(x.shape, np.exp(np.clip(x.data, -300, 300), x.distaxis)) -def hardplus(x): - return data_object(x.shape, np.clip(x.data, 1e-20, None), x.distaxis) +def hardplus(x, eps): + return data_object(x.shape, np.clip(x.data, eps, None), x.distaxis) def from_object(object, dtype, copy, set_locked): diff --git a/nifty5/data_objects/numpy_do.py b/nifty5/data_objects/numpy_do.py index 60dda9ad5e11b1c291ac90f4b43b6c883df287e8..ab978f4e87ae81b79ac373dcc5ac8e0c4cae8c51 100644 --- a/nifty5/data_objects/numpy_do.py +++ b/nifty5/data_objects/numpy_do.py @@ -158,5 +158,5 @@ def clipped_exp(arr): return np.exp(np.clip(arr, -300, 300)) -def hardplus(arr): - return np.clip(arr, 1e-20, None) +def hardplus(arr, eps): + return np.clip(arr, eps, None) diff --git a/nifty5/field.py b/nifty5/field.py index b5ab6f1b3bb3664e1a32c1b7dd55c4d125481a8a..52e78901f00915aee3959dfb9acea00f861101c5 100644 --- a/nifty5/field.py +++ b/nifty5/field.py @@ -637,8 +637,8 @@ class Field(object): def clipped_exp(self): return Field(self._domain, dobj.clipped_exp(self._val)) - def hardplus(self): - return Field(self._domain, dobj.hardplus(self._val)) + def hardplus(self, eps): + return Field(self._domain, dobj.hardplus(self._val, eps)) def one_over(self): return 1/self diff --git a/nifty5/linearization.py b/nifty5/linearization.py index a8122c6513f866d8aa56be0ffc90289a4d7f43d1..c7cf837dc5d01ab8f4c77c874fbd18ad8273092f 100644 --- a/nifty5/linearization.py +++ b/nifty5/linearization.py @@ -187,9 +187,9 @@ class Linearization(object): tmp = self._val.clipped_exp() return self.new(tmp, makeOp(tmp)(self._jac)) - def hardplus(self): - tmp = self._val.hardplus() - tmp2 = makeOp(1.-(tmp == 1e-20)) + def hardplus(self, eps): + tmp = self._val.hardplus(eps) + tmp2 = makeOp(1.-(tmp == eps)) return self.new(tmp, tmp2(self._jac)) def sin(self):