Commit 93833052 authored by Martin Reinecke's avatar Martin Reinecke

hardplus->clip

parent d043574e
...@@ -36,7 +36,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", ...@@ -36,7 +36,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"lock", "locked", "uniform_full", "transpose", "to_global_data_rw", "lock", "locked", "uniform_full", "transpose", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed", "ensure_not_distributed", "ensure_default_distributed",
"clipped_exp", "tanh", "conjugate", "sin", "cos", "tan", "clipped_exp", "tanh", "conjugate", "sin", "cos", "tan",
"sinh", "cosh", "sinc", "absolute", "sign", "hardplus"] "sinh", "cosh", "sinc", "absolute", "sign", "clip"]
_comm = MPI.COMM_WORLD _comm = MPI.COMM_WORLD
ntask = _comm.Get_size() ntask = _comm.Get_size()
...@@ -216,6 +216,9 @@ class data_object(object): ...@@ -216,6 +216,9 @@ class data_object(object):
else: else:
return data_object(self._shape, tval, self._distaxis) return data_object(self._shape, tval, self._distaxis)
def clip(self, min=None, max=None):
return data_object(self._shape, np.clip(self._data, min, max))
def __neg__(self): def __neg__(self):
return data_object(self._shape, -self._data, self._distaxis) return data_object(self._shape, -self._data, self._distaxis)
...@@ -310,8 +313,8 @@ def clipped_exp(x): ...@@ -310,8 +313,8 @@ def clipped_exp(x):
return data_object(x.shape, np.exp(np.clip(x.data, -300, 300), x.distaxis)) return data_object(x.shape, np.exp(np.clip(x.data, -300, 300), x.distaxis))
def hardplus(x, eps): def clip(x, a_min=None, a_max=None):
return data_object(x.shape, np.clip(x.data, eps, None), x.distaxis) return data_object(x.shape, np.clip(x.data, a_min, a_max), x.distaxis)
def from_object(object, dtype, copy, set_locked): def from_object(object, dtype, copy, set_locked):
......
...@@ -23,7 +23,7 @@ from numpy import empty, empty_like, exp, full, log ...@@ -23,7 +23,7 @@ from numpy import empty, empty_like, exp, full, log
from numpy import ndarray as data_object from numpy import ndarray as data_object
from numpy import ones, sqrt, tanh, vdot, zeros from numpy import ones, sqrt, tanh, vdot, zeros
from numpy import sin, cos, tan, sinh, cosh, sinc from numpy import sin, cos, tan, sinh, cosh, sinc
from numpy import absolute, sign from numpy import absolute, sign, clip
from .random import Random from .random import Random
__all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
...@@ -35,7 +35,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", ...@@ -35,7 +35,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm", "redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "to_global_data_rw", "lock", "locked", "uniform_full", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed", "ensure_not_distributed", "ensure_default_distributed",
"clipped_exp", "hardplus", "sin", "cos", "tan", "sinh", "clipped_exp", "clip", "sin", "cos", "tan", "sinh",
"cosh", "absolute", "sign", "sinc"] "cosh", "absolute", "sign", "sinc"]
ntask = 1 ntask = 1
...@@ -156,7 +156,3 @@ def norm(arr, ord=2): ...@@ -156,7 +156,3 @@ def norm(arr, ord=2):
def clipped_exp(arr): def clipped_exp(arr):
return np.exp(np.clip(arr, -300, 300)) return np.exp(np.clip(arr, -300, 300))
def hardplus(arr, eps):
return np.clip(arr, eps, None)
...@@ -637,8 +637,8 @@ class Field(object): ...@@ -637,8 +637,8 @@ class Field(object):
def clipped_exp(self): def clipped_exp(self):
return Field(self._domain, dobj.clipped_exp(self._val)) return Field(self._domain, dobj.clipped_exp(self._val))
def hardplus(self, eps): def clip(self, min=None, max=None):
return Field(self._domain, dobj.hardplus(self._val, eps)) return Field(self._domain, dobj.clip(self._val, min, max))
def one_over(self): def one_over(self):
return 1/self return 1/self
......
...@@ -187,8 +187,8 @@ class Linearization(object): ...@@ -187,8 +187,8 @@ class Linearization(object):
tmp = self._val.clipped_exp() tmp = self._val.clipped_exp()
return self.new(tmp, makeOp(tmp)(self._jac)) return self.new(tmp, makeOp(tmp)(self._jac))
def hardplus(self, eps): def clip(self, min=None, max=None):
tmp = self._val.hardplus(eps) tmp = self._val.clip(min, max)
tmp2 = makeOp(1.-(tmp == eps)) tmp2 = makeOp(1.-(tmp == eps))
return self.new(tmp, tmp2(self._jac)) return self.new(tmp, tmp2(self._jac))
......
...@@ -78,6 +78,11 @@ class Operator(NiftyMetaBase()): ...@@ -78,6 +78,11 @@ class Operator(NiftyMetaBase()):
return NotImplemented return NotImplemented
return _OpChain.make((_PowerOp(self.target, power), self)) return _OpChain.make((_PowerOp(self.target, power), self))
def clip(self, min=None, max=None):
if min is None and max is None:
return self
return _OpChain.make((_Clipper(sef.target, min, max), self))
def apply(self, x): def apply(self, x):
raise NotImplementedError raise NotImplementedError
...@@ -108,7 +113,7 @@ class Operator(NiftyMetaBase()): ...@@ -108,7 +113,7 @@ class Operator(NiftyMetaBase()):
for f in ["sqrt", "exp", "log", "tanh", "sigmoid", for f in ["sqrt", "exp", "log", "tanh", "sigmoid",
'clipped_exp', 'hardplus', 'sin', 'cos', 'tan', 'clipped_exp', 'sin', 'cos', 'tan',
'sinh', 'cosh', 'absolute', 'sinc', 'one_over']: 'sinh', 'cosh', 'absolute', 'sinc', 'one_over']:
def func(f): def func(f):
def func2(self): def func2(self):
...@@ -129,6 +134,18 @@ class _FunctionApplier(Operator): ...@@ -129,6 +134,18 @@ class _FunctionApplier(Operator):
return getattr(x, self._funcname)() return getattr(x, self._funcname)()
class _Clipper(Operator):
def __init__(self, domain, min=None, max=None):
from ..sugar import makeDomain
self._domain = self._target = makeDomain(domain)
self._min = min
self._max = max
def apply(self, x):
self._check_input(x)
return x.clip(self._min, self._max)
class _PowerOp(Operator): class _PowerOp(Operator):
def __init__(self, domain, power): def __init__(self, domain, power):
from ..sugar import makeDomain from ..sugar import makeDomain
......
...@@ -39,7 +39,7 @@ __all__ = ['PS_field', 'power_analyze', 'create_power_operator', ...@@ -39,7 +39,7 @@ __all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'full', 'from_global_data', 'from_local_data', 'full', 'from_global_data', 'from_local_data',
'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'sigmoid', 'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'sigmoid',
'sin', 'cos', 'tan', 'sinh', 'cosh', 'sin', 'cos', 'tan', 'sinh', 'cosh',
'absolute', 'one_over', 'hardplus', 'sinc', 'absolute', 'one_over', 'clip', 'sinc',
'conjugate', 'get_signal_variance', 'makeOp', 'domain_union', 'conjugate', 'get_signal_variance', 'makeOp', 'domain_union',
'get_default_codomain'] 'get_default_codomain']
...@@ -261,7 +261,7 @@ _current_module = sys.modules[__name__] ...@@ -261,7 +261,7 @@ _current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "sigmoid", for f in ["sqrt", "exp", "log", "tanh", "sigmoid",
"conjugate", 'sin', 'cos', 'tan', 'sinh', 'cosh', "conjugate", 'sin', 'cos', 'tan', 'sinh', 'cosh',
'absolute', 'one_over', 'hardplus', 'sinc']: 'absolute', 'one_over', 'sinc']:
def func(f): def func(f):
def func2(x): def func2(x):
from .linearization import Linearization from .linearization import Linearization
...@@ -273,6 +273,11 @@ for f in ["sqrt", "exp", "log", "tanh", "sigmoid", ...@@ -273,6 +273,11 @@ for f in ["sqrt", "exp", "log", "tanh", "sigmoid",
return func2 return func2
setattr(_current_module, f, func(f)) setattr(_current_module, f, func(f))
def clip(a, a_min=None, a_max=None):
return a.clip(a_min, a_max)
def get_default_codomain(domainoid, space=None): def get_default_codomain(domainoid, space=None):
"""For `RGSpace`, returns the harmonic partner domain. """For `RGSpace`, returns the harmonic partner domain.
For `DomainTuple`, returns a copy of the object in which the domain For `DomainTuple`, returns a copy of the object in which the domain
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment