Commit 93833052 authored by Martin Reinecke's avatar Martin Reinecke

hardplus->clip

parent d043574e
......@@ -36,7 +36,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"lock", "locked", "uniform_full", "transpose", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed",
"clipped_exp", "tanh", "conjugate", "sin", "cos", "tan",
"sinh", "cosh", "sinc", "absolute", "sign", "hardplus"]
"sinh", "cosh", "sinc", "absolute", "sign", "clip"]
_comm = MPI.COMM_WORLD
ntask = _comm.Get_size()
......@@ -216,6 +216,9 @@ class data_object(object):
else:
return data_object(self._shape, tval, self._distaxis)
def clip(self, min=None, max=None):
return data_object(self._shape, np.clip(self._data, min, max))
def __neg__(self):
return data_object(self._shape, -self._data, self._distaxis)
......@@ -310,8 +313,8 @@ def clipped_exp(x):
return data_object(x.shape, np.exp(np.clip(x.data, -300, 300), x.distaxis))
def hardplus(x, eps):
return data_object(x.shape, np.clip(x.data, eps, None), x.distaxis)
def clip(x, a_min=None, a_max=None):
return data_object(x.shape, np.clip(x.data, a_min, a_max), x.distaxis)
def from_object(object, dtype, copy, set_locked):
......
......@@ -23,7 +23,7 @@ from numpy import empty, empty_like, exp, full, log
from numpy import ndarray as data_object
from numpy import ones, sqrt, tanh, vdot, zeros
from numpy import sin, cos, tan, sinh, cosh, sinc
from numpy import absolute, sign
from numpy import absolute, sign, clip
from .random import Random
__all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
......@@ -35,7 +35,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed",
"clipped_exp", "hardplus", "sin", "cos", "tan", "sinh",
"clipped_exp", "clip", "sin", "cos", "tan", "sinh",
"cosh", "absolute", "sign", "sinc"]
ntask = 1
......@@ -156,7 +156,3 @@ def norm(arr, ord=2):
def clipped_exp(arr):
return np.exp(np.clip(arr, -300, 300))
def hardplus(arr, eps):
return np.clip(arr, eps, None)
......@@ -637,8 +637,8 @@ class Field(object):
def clipped_exp(self):
return Field(self._domain, dobj.clipped_exp(self._val))
def hardplus(self, eps):
return Field(self._domain, dobj.hardplus(self._val, eps))
def clip(self, min=None, max=None):
return Field(self._domain, dobj.clip(self._val, min, max))
def one_over(self):
return 1/self
......
......@@ -187,8 +187,8 @@ class Linearization(object):
tmp = self._val.clipped_exp()
return self.new(tmp, makeOp(tmp)(self._jac))
def hardplus(self, eps):
tmp = self._val.hardplus(eps)
def clip(self, min=None, max=None):
tmp = self._val.clip(min, max)
tmp2 = makeOp(1.-(tmp == eps))
return self.new(tmp, tmp2(self._jac))
......
......@@ -78,6 +78,11 @@ class Operator(NiftyMetaBase()):
return NotImplemented
return _OpChain.make((_PowerOp(self.target, power), self))
def clip(self, min=None, max=None):
if min is None and max is None:
return self
return _OpChain.make((_Clipper(sef.target, min, max), self))
def apply(self, x):
raise NotImplementedError
......@@ -108,7 +113,7 @@ class Operator(NiftyMetaBase()):
for f in ["sqrt", "exp", "log", "tanh", "sigmoid",
'clipped_exp', 'hardplus', 'sin', 'cos', 'tan',
'clipped_exp', 'sin', 'cos', 'tan',
'sinh', 'cosh', 'absolute', 'sinc', 'one_over']:
def func(f):
def func2(self):
......@@ -129,6 +134,18 @@ class _FunctionApplier(Operator):
return getattr(x, self._funcname)()
class _Clipper(Operator):
def __init__(self, domain, min=None, max=None):
from ..sugar import makeDomain
self._domain = self._target = makeDomain(domain)
self._min = min
self._max = max
def apply(self, x):
self._check_input(x)
return x.clip(self._min, self._max)
class _PowerOp(Operator):
def __init__(self, domain, power):
from ..sugar import makeDomain
......
......@@ -39,7 +39,7 @@ __all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'full', 'from_global_data', 'from_local_data',
'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'sigmoid',
'sin', 'cos', 'tan', 'sinh', 'cosh',
'absolute', 'one_over', 'hardplus', 'sinc',
'absolute', 'one_over', 'clip', 'sinc',
'conjugate', 'get_signal_variance', 'makeOp', 'domain_union',
'get_default_codomain']
......@@ -261,7 +261,7 @@ _current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "sigmoid",
"conjugate", 'sin', 'cos', 'tan', 'sinh', 'cosh',
'absolute', 'one_over', 'hardplus', 'sinc']:
'absolute', 'one_over', 'sinc']:
def func(f):
def func2(x):
from .linearization import Linearization
......@@ -273,6 +273,11 @@ for f in ["sqrt", "exp", "log", "tanh", "sigmoid",
return func2
setattr(_current_module, f, func(f))
def clip(a, a_min=None, a_max=None):
return a.clip(a_min, a_max)
def get_default_codomain(domainoid, space=None):
"""For `RGSpace`, returns the harmonic partner domain.
For `DomainTuple`, returns a copy of the object in which the domain
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment