Commit 4aa2cb82 authored by Jakob Knollmueller's avatar Jakob Knollmueller
Browse files

added a number of local nonlinear functions

parent eff96636
...@@ -22,7 +22,8 @@ import numpy as np ...@@ -22,7 +22,8 @@ import numpy as np
from numpy import empty, empty_like, exp, full, log from numpy import empty, empty_like, exp, full, log
from numpy import ndarray as data_object from numpy import ndarray as data_object
from numpy import ones, sqrt, tanh, vdot, zeros from numpy import ones, sqrt, tanh, vdot, zeros
from numpy import sin, cos, tan, sinh, cosh, sinc
from numpy import absolute, sign
from .random import Random from .random import Random
__all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
...@@ -34,7 +35,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", ...@@ -34,7 +35,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm", "redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "to_global_data_rw", "lock", "locked", "uniform_full", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed", "ensure_not_distributed", "ensure_default_distributed",
"clipped_exp"] "clipped_exp", "hardplus", "sin", "cos", "tan", "sinh",
"cosh","absolute", "sign", "sinc"]
ntask = 1 ntask = 1
rank = 0 rank = 0
...@@ -154,3 +156,7 @@ def norm(arr, ord=2): ...@@ -154,3 +156,7 @@ def norm(arr, ord=2):
def clipped_exp(arr): def clipped_exp(arr):
return np.exp(np.clip(arr, -300, 300)) return np.exp(np.clip(arr, -300, 300))
def hardplus(arr):
return np.clip(arr, 1e-20, None)
\ No newline at end of file
...@@ -631,12 +631,18 @@ class Field(object): ...@@ -631,12 +631,18 @@ class Field(object):
def flexible_addsub(self, other, neg): def flexible_addsub(self, other, neg):
return self-other if neg else self+other return self-other if neg else self+other
def positive_tanh(self): def sigmoid(self):
return 0.5*(1.+self.tanh()) return 0.5*(1.+self.tanh())
def clipped_exp(self): def clipped_exp(self):
return Field(self._domain, dobj.clipped_exp(self._val)) return Field(self._domain, dobj.clipped_exp(self._val))
def hardplus(self):
return Field(self._domain, dobj.hardplus(self._val))
def one_over(self):
return 1/self
def _binary_op(self, other, op): def _binary_op(self, other, op):
# if other is a field, make sure that the domains match # if other is a field, make sure that the domains match
f = getattr(self._val, op) f = getattr(self._val, op)
...@@ -672,7 +678,9 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", ...@@ -672,7 +678,9 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
return func2 return func2
setattr(Field, op, func(op)) setattr(Field, op, func(op))
for f in ["sqrt", "exp", "log", "tanh"]: for f in ["sqrt", "exp", "log", "tanh",
"sin", "cos", "tan", "cosh", "sinh",
"absolute", "sinc", "sign"]:
def func(f): def func(f):
def func2(self): def func2(self):
return Field(self._domain, getattr(dobj, f)(self.val)) return Field(self._domain, getattr(dobj, f)(self.val))
......
...@@ -187,19 +187,64 @@ class Linearization(object): ...@@ -187,19 +187,64 @@ class Linearization(object):
tmp = self._val.clipped_exp() tmp = self._val.clipped_exp()
return self.new(tmp, makeOp(tmp)(self._jac)) return self.new(tmp, makeOp(tmp)(self._jac))
def hardplus(self):
tmp = self._val.hardplus()
tmp2 = makeOp(1.-(tmp==1e-20))
return self.new(tmp, tmp2(self._jac))
def sin(self):
tmp = self._val.sin()
tmp2 = self._val.cos()
return self.new(tmp, makeOp(tmp2)(self._jac))
def cos(self):
tmp = self._val.cos()
tmp2 = - self._val.sin()
return self.new(tmp, makeOp(tmp2)(self._jac))
def tan(self):
tmp = self._val.tan()
tmp2 = 1./(self._val.cos()**2)
return self.new(tmp, makeOp(tmp2)(self._jac))
def sinc(self):
tmp = self._val.sinc()
tmp2 = (self._val.cos()-tmp)/self._val
return self.new(tmp, makeOp(tmp2)(self._jac))
def log(self): def log(self):
tmp = self._val.log() tmp = self._val.log()
return self.new(tmp, makeOp(1./self._val)(self._jac)) return self.new(tmp, makeOp(1./self._val)(self._jac))
def sinh(self):
tmp = self._val.sinh()
tmp2 = self._val.cosh()
return self.new(tmp, makeOp(tmp2)(self._jac))
def cosh(self):
tmp = self._val.cosh()
tmp2 = self._val.sinh()
return self.new(tmp, makeOp(tmp2)(self._jac))
def tanh(self): def tanh(self):
tmp = self._val.tanh() tmp = self._val.tanh()
return self.new(tmp, makeOp(1.-tmp**2)(self._jac)) return self.new(tmp, makeOp(1.-tmp**2)(self._jac))
def positive_tanh(self): def sigmoid(self):
tmp = self._val.tanh() tmp = self._val.tanh()
tmp2 = 0.5*(1.+tmp) tmp2 = 0.5*(1.+tmp)
return self.new(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac)) return self.new(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac))
def absolute(self):
tmp = self._val.absolute()
tmp2 = self._val.sign()
return self.new(tmp, makeOp(tmp2)(self._jac))
def one_over(self):
tmp = 1./self._val
tmp2 = - tmp/self._val
return self.new(tmp, makeOp(tmp2)(self._jac))
def add_metric(self, metric): def add_metric(self, metric):
return self.new(self._val, self._jac, metric) return self.new(self._val, self._jac, metric)
......
...@@ -107,7 +107,9 @@ class Operator(NiftyMetaBase()): ...@@ -107,7 +107,9 @@ class Operator(NiftyMetaBase()):
return self.__class__.__name__ return self.__class__.__name__
for f in ["sqrt", "exp", "log", "tanh", "positive_tanh", 'clipped_exp']: for f in ["sqrt", "exp", "log", "tanh", "sigmoid",
'clipped_exp', 'hardplus', 'sin', 'cos', 'tan',
'sinh', 'cosh', 'absolute', 'sinc', 'one_over']:
def func(f): def func(f):
def func2(self): def func2(self):
fa = _FunctionApplier(self.target, f) fa = _FunctionApplier(self.target, f)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment