Commit f0a24f5e authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweaks

parent 0e63c553
......@@ -34,7 +34,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"distaxis", "from_local_data", "from_global_data", "to_global_data",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "transpose", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed"]
"ensure_not_distributed", "ensure_default_distributed",
"clipped_exp"]
_comm = MPI.COMM_WORLD
ntask = _comm.Get_size()
......@@ -303,6 +304,10 @@ for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
setattr(_current_module, f, func(f))
def clipped_exp(a):
return data_object(x.shape, np.exp(np.clip(x.data, -300, 300), x.distaxis)
def from_object(object, dtype, copy, set_locked):
if dtype is None:
dtype = object.dtype
......
......@@ -33,7 +33,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"distaxis", "from_local_data", "from_global_data", "to_global_data",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed"]
"ensure_not_distributed", "ensure_default_distributed",
"clipped_exp"]
ntask = 1
rank = 0
......@@ -149,3 +150,7 @@ def absmax(arr):
def norm(arr, ord=2):
return np.linalg.norm(arr.reshape(-1), ord=ord)
def clipped_exp(arr):
return np.exp(np.clip(arr, -300, 300))
......@@ -634,6 +634,9 @@ class Field(object):
def positive_tanh(self):
return 0.5*(1.+self.tanh())
def clipped_exp(self):
return Field(self._domain, dobj.clipped_exp(self._val))
def _binary_op(self, other, op):
# if other is a field, make sure that the domains match
f = getattr(self._val, op)
......@@ -675,9 +678,3 @@ for f in ["sqrt", "exp", "log", "tanh"]:
return Field(self._domain, getattr(dobj, f)(self.val))
return func2
setattr(Field, f, func(f))
def func2(self):
np.clip(self.val, -300, 300, out=self.val)
return Field(self._domain, getattr(dobj, 'exp')(self.val))
setattr(Field, 'clipped_exp', func2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment