Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
4aa2cb82
Commit
4aa2cb82
authored
Dec 15, 2018
by
Jakob Knollmueller
Browse files
added a number of local nonlinear functions
parent
eff96636
Changes
4
Hide whitespace changes
Inline
Side-by-side
nifty5/data_objects/numpy_do.py
View file @
4aa2cb82
...
...
@@ -22,7 +22,8 @@ import numpy as np
from
numpy
import
empty
,
empty_like
,
exp
,
full
,
log
from
numpy
import
ndarray
as
data_object
from
numpy
import
ones
,
sqrt
,
tanh
,
vdot
,
zeros
from
numpy
import
sin
,
cos
,
tan
,
sinh
,
cosh
,
sinc
from
numpy
import
absolute
,
sign
from
.random
import
Random
__all__
=
[
"ntask"
,
"rank"
,
"master"
,
"local_shape"
,
"data_object"
,
"full"
,
...
...
@@ -34,7 +35,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"redistribute"
,
"default_distaxis"
,
"is_numpy"
,
"absmax"
,
"norm"
,
"lock"
,
"locked"
,
"uniform_full"
,
"to_global_data_rw"
,
"ensure_not_distributed"
,
"ensure_default_distributed"
,
"clipped_exp"
]
"clipped_exp"
,
"hardplus"
,
"sin"
,
"cos"
,
"tan"
,
"sinh"
,
"cosh"
,
"absolute"
,
"sign"
,
"sinc"
]
ntask
=
1
rank
=
0
...
...
@@ -154,3 +156,7 @@ def norm(arr, ord=2):
def
clipped_exp
(
arr
):
return
np
.
exp
(
np
.
clip
(
arr
,
-
300
,
300
))
def
hardplus
(
arr
):
return
np
.
clip
(
arr
,
1e-20
,
None
)
\ No newline at end of file
nifty5/field.py
View file @
4aa2cb82
...
...
@@ -631,12 +631,18 @@ class Field(object):
def
flexible_addsub
(
self
,
other
,
neg
):
return
self
-
other
if
neg
else
self
+
other
def
positive_tanh
(
self
):
def
sigmoid
(
self
):
return
0.5
*
(
1.
+
self
.
tanh
())
def
clipped_exp
(
self
):
return
Field
(
self
.
_domain
,
dobj
.
clipped_exp
(
self
.
_val
))
def
hardplus
(
self
):
return
Field
(
self
.
_domain
,
dobj
.
hardplus
(
self
.
_val
))
def
one_over
(
self
):
return
1
/
self
def
_binary_op
(
self
,
other
,
op
):
# if other is a field, make sure that the domains match
f
=
getattr
(
self
.
_val
,
op
)
...
...
@@ -672,7 +678,9 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
return
func2
setattr
(
Field
,
op
,
func
(
op
))
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"tanh"
]:
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"tanh"
,
"sin"
,
"cos"
,
"tan"
,
"cosh"
,
"sinh"
,
"absolute"
,
"sinc"
,
"sign"
]:
def
func
(
f
):
def
func2
(
self
):
return
Field
(
self
.
_domain
,
getattr
(
dobj
,
f
)(
self
.
val
))
...
...
nifty5/linearization.py
View file @
4aa2cb82
...
...
@@ -187,19 +187,64 @@ class Linearization(object):
tmp
=
self
.
_val
.
clipped_exp
()
return
self
.
new
(
tmp
,
makeOp
(
tmp
)(
self
.
_jac
))
def
hardplus
(
self
):
tmp
=
self
.
_val
.
hardplus
()
tmp2
=
makeOp
(
1.
-
(
tmp
==
1e-20
))
return
self
.
new
(
tmp
,
tmp2
(
self
.
_jac
))
def
sin
(
self
):
tmp
=
self
.
_val
.
sin
()
tmp2
=
self
.
_val
.
cos
()
return
self
.
new
(
tmp
,
makeOp
(
tmp2
)(
self
.
_jac
))
def
cos
(
self
):
tmp
=
self
.
_val
.
cos
()
tmp2
=
-
self
.
_val
.
sin
()
return
self
.
new
(
tmp
,
makeOp
(
tmp2
)(
self
.
_jac
))
def
tan
(
self
):
tmp
=
self
.
_val
.
tan
()
tmp2
=
1.
/
(
self
.
_val
.
cos
()
**
2
)
return
self
.
new
(
tmp
,
makeOp
(
tmp2
)(
self
.
_jac
))
def
sinc
(
self
):
tmp
=
self
.
_val
.
sinc
()
tmp2
=
(
self
.
_val
.
cos
()
-
tmp
)
/
self
.
_val
return
self
.
new
(
tmp
,
makeOp
(
tmp2
)(
self
.
_jac
))
def
log
(
self
):
tmp
=
self
.
_val
.
log
()
return
self
.
new
(
tmp
,
makeOp
(
1.
/
self
.
_val
)(
self
.
_jac
))
def
sinh
(
self
):
tmp
=
self
.
_val
.
sinh
()
tmp2
=
self
.
_val
.
cosh
()
return
self
.
new
(
tmp
,
makeOp
(
tmp2
)(
self
.
_jac
))
def
cosh
(
self
):
tmp
=
self
.
_val
.
cosh
()
tmp2
=
self
.
_val
.
sinh
()
return
self
.
new
(
tmp
,
makeOp
(
tmp2
)(
self
.
_jac
))
def
tanh
(
self
):
tmp
=
self
.
_val
.
tanh
()
return
self
.
new
(
tmp
,
makeOp
(
1.
-
tmp
**
2
)(
self
.
_jac
))
def
positive_tanh
(
self
):
def
sigmoid
(
self
):
tmp
=
self
.
_val
.
tanh
()
tmp2
=
0.5
*
(
1.
+
tmp
)
return
self
.
new
(
tmp2
,
makeOp
(
0.5
*
(
1.
-
tmp
**
2
))(
self
.
_jac
))
def
absolute
(
self
):
tmp
=
self
.
_val
.
absolute
()
tmp2
=
self
.
_val
.
sign
()
return
self
.
new
(
tmp
,
makeOp
(
tmp2
)(
self
.
_jac
))
def
one_over
(
self
):
tmp
=
1.
/
self
.
_val
tmp2
=
-
tmp
/
self
.
_val
return
self
.
new
(
tmp
,
makeOp
(
tmp2
)(
self
.
_jac
))
def
add_metric
(
self
,
metric
):
return
self
.
new
(
self
.
_val
,
self
.
_jac
,
metric
)
...
...
nifty5/operators/operator.py
View file @
4aa2cb82
...
...
@@ -107,7 +107,9 @@ class Operator(NiftyMetaBase()):
return
self
.
__class__
.
__name__
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"tanh"
,
"positive_tanh"
,
'clipped_exp'
]:
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"tanh"
,
"sigmoid"
,
'clipped_exp'
,
'hardplus'
,
'sin'
,
'cos'
,
'tan'
,
'sinh'
,
'cosh'
,
'absolute'
,
'sinc'
,
'one_over'
]:
def
func
(
f
):
def
func2
(
self
):
fa
=
_FunctionApplier
(
self
.
target
,
f
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment