Commit f0a24f5e authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweaks

parent 0e63c553
...@@ -34,7 +34,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", ...@@ -34,7 +34,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"distaxis", "from_local_data", "from_global_data", "to_global_data", "distaxis", "from_local_data", "from_global_data", "to_global_data",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm", "redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "transpose", "to_global_data_rw", "lock", "locked", "uniform_full", "transpose", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed"] "ensure_not_distributed", "ensure_default_distributed",
"clipped_exp"]
_comm = MPI.COMM_WORLD _comm = MPI.COMM_WORLD
ntask = _comm.Get_size() ntask = _comm.Get_size()
...@@ -303,6 +304,10 @@ for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: ...@@ -303,6 +304,10 @@ for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
setattr(_current_module, f, func(f)) setattr(_current_module, f, func(f))
def clipped_exp(a):
return data_object(x.shape, np.exp(np.clip(x.data, -300, 300), x.distaxis)
def from_object(object, dtype, copy, set_locked): def from_object(object, dtype, copy, set_locked):
if dtype is None: if dtype is None:
dtype = object.dtype dtype = object.dtype
......
...@@ -33,7 +33,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", ...@@ -33,7 +33,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"distaxis", "from_local_data", "from_global_data", "to_global_data", "distaxis", "from_local_data", "from_global_data", "to_global_data",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm", "redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "to_global_data_rw", "lock", "locked", "uniform_full", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed"] "ensure_not_distributed", "ensure_default_distributed",
"clipped_exp"]
ntask = 1 ntask = 1
rank = 0 rank = 0
...@@ -149,3 +150,7 @@ def absmax(arr): ...@@ -149,3 +150,7 @@ def absmax(arr):
def norm(arr, ord=2): def norm(arr, ord=2):
return np.linalg.norm(arr.reshape(-1), ord=ord) return np.linalg.norm(arr.reshape(-1), ord=ord)
def clipped_exp(arr):
return np.exp(np.clip(arr, -300, 300))
...@@ -634,6 +634,9 @@ class Field(object): ...@@ -634,6 +634,9 @@ class Field(object):
def positive_tanh(self): def positive_tanh(self):
return 0.5*(1.+self.tanh()) return 0.5*(1.+self.tanh())
def clipped_exp(self):
return Field(self._domain, dobj.clipped_exp(self._val))
def _binary_op(self, other, op): def _binary_op(self, other, op):
# if other is a field, make sure that the domains match # if other is a field, make sure that the domains match
f = getattr(self._val, op) f = getattr(self._val, op)
...@@ -675,9 +678,3 @@ for f in ["sqrt", "exp", "log", "tanh"]: ...@@ -675,9 +678,3 @@ for f in ["sqrt", "exp", "log", "tanh"]:
return Field(self._domain, getattr(dobj, f)(self.val)) return Field(self._domain, getattr(dobj, f)(self.val))
return func2 return func2
setattr(Field, f, func(f)) setattr(Field, f, func(f))
def func2(self):
np.clip(self.val, -300, 300, out=self.val)
return Field(self._domain, getattr(dobj, 'exp')(self.val))
setattr(Field, 'clipped_exp', func2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment