Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
8b2f500c
Commit
8b2f500c
authored
Nov 15, 2019
by
Reimar H Leike
Browse files
build in log1p as a nonlinearity instead of as an operator
parent
66ccf5ea
Changes
9
Hide whitespace changes
Inline
Side-by-side
nifty5/__init__.py
View file @
8b2f500c
...
...
@@ -20,7 +20,6 @@ from .multi_field import MultiField
from
.operators.operator
import
Operator
from
.operators.adder
import
Adder
from
.operators.log1p
import
Log1p
from
.operators.diagonal_operator
import
DiagonalOperator
from
.operators.distributors
import
DOFDistributor
,
PowerDistributor
from
.operators.domain_tuple_field_inserter
import
DomainTupleFieldInserter
...
...
nifty5/data_objects/distributed_do.py
View file @
8b2f500c
...
...
@@ -32,7 +32,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"redistribute"
,
"default_distaxis"
,
"is_numpy"
,
"absmax"
,
"norm"
,
"lock"
,
"locked"
,
"uniform_full"
,
"transpose"
,
"to_global_data_rw"
,
"ensure_not_distributed"
,
"ensure_default_distributed"
,
"tanh"
,
"conjugate"
,
"sin"
,
"cos"
,
"tan"
,
"log10"
,
"tanh"
,
"conjugate"
,
"sin"
,
"cos"
,
"tan"
,
"log10"
,
"log1p"
,
"sinh"
,
"cosh"
,
"sinc"
,
"absolute"
,
"sign"
,
"clip"
]
_comm
=
MPI
.
COMM_WORLD
...
...
@@ -297,7 +297,7 @@ def _math_helper(x, function, out):
_current_module
=
sys
.
modules
[
__name__
]
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"tanh"
,
"conjugate"
,
"sin"
,
"cos"
,
"tan"
,
"sinh"
,
"cosh"
,
"sinc"
,
"absolute"
,
"sign"
,
"log10"
]:
"sinh"
,
"cosh"
,
"sinc"
,
"absolute"
,
"sign"
,
"log10"
,
"log1p"
]:
def
func
(
f
):
def
func2
(
x
,
out
=
None
):
return
_math_helper
(
x
,
f
,
out
)
...
...
nifty5/data_objects/numpy_do.py
View file @
8b2f500c
...
...
@@ -22,7 +22,7 @@ from numpy import ndarray as data_object
from
numpy
import
empty
,
empty_like
,
ones
,
zeros
,
full
from
numpy
import
absolute
,
sign
,
clip
,
vdot
from
numpy
import
sin
,
cos
,
sinh
,
cosh
,
tan
,
tanh
from
numpy
import
exp
,
log
,
log10
,
sqrt
,
sinc
from
numpy
import
exp
,
log
,
log10
,
sqrt
,
sinc
,
log1p
from
.random
import
Random
...
...
@@ -36,7 +36,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"lock"
,
"locked"
,
"uniform_full"
,
"to_global_data_rw"
,
"ensure_not_distributed"
,
"ensure_default_distributed"
,
"clip"
,
"sin"
,
"cos"
,
"tan"
,
"sinh"
,
"cosh"
,
"absolute"
,
"sign"
,
"sinc"
,
"log10"
]
"cosh"
,
"absolute"
,
"sign"
,
"sinc"
,
"log10"
,
"log1p"
]
ntask
=
1
rank
=
0
...
...
nifty5/field.py
View file @
8b2f500c
...
...
@@ -663,7 +663,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
return
func2
setattr
(
Field
,
op
,
func
(
op
))
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"log10"
,
"tanh"
,
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"log10"
,
"log1p"
,
"tanh"
,
"sin"
,
"cos"
,
"tan"
,
"cosh"
,
"sinh"
,
"absolute"
,
"sinc"
,
"sign"
]:
def
func
(
f
):
...
...
nifty5/linearization.py
View file @
8b2f500c
...
...
@@ -335,6 +335,12 @@ class Linearization(object):
tmp2
=
1.
/
(
self
.
_val
*
np
.
log
(
10
))
return
self
.
new
(
tmp
,
makeOp
(
tmp2
)(
self
.
_jac
))
def
log1p
(
self
):
xval
=
self
.
val
res
=
xval
.
log1p
()
jac
=
makeOp
(
1.
/
(
1.
+
xval
))
return
self
.
new
(
res
,
jac
@
self
.
jac
)
def
sinh
(
self
):
tmp
=
self
.
_val
.
sinh
()
tmp2
=
self
.
_val
.
cosh
()
...
...
nifty5/multi_field.py
View file @
8b2f500c
...
...
@@ -338,7 +338,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
setattr
(
MultiField
,
op
,
func
(
op
))
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"tanh"
]:
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"log1p"
,
"tanh"
]:
def
func
(
f
):
def
func2
(
self
):
fu
=
getattr
(
Field
,
f
)
...
...
nifty5/operators/energy_operators.py
View file @
8b2f500c
...
...
@@ -269,12 +269,10 @@ class StudentTEnergy(EnergyOperator):
def
__init__
(
self
,
domain
,
theta
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_theta
=
theta
from
.log1p
import
Log1p
self
.
_l1p
=
Log1p
(
domain
)
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
v
=
((
self
.
_theta
+
1
)
/
2
)
*
self
.
_l1p
(
x
**
2
/
self
.
_theta
).
sum
()
v
=
((
self
.
_theta
+
1
)
/
2
)
*
(
x
**
2
/
self
.
_theta
).
log1p
().
sum
()
if
not
isinstance
(
x
,
Linearization
):
return
Field
.
scalar
(
v
)
if
not
x
.
want_metric
:
...
...
nifty5/operators/log1p.py
deleted
100644 → 0
View file @
66ccf5ea
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from
..field
import
Field
from
..multi_field
import
MultiField
from
.operator
import
Operator
from
.diagonal_operator
import
DiagonalOperator
from
..linearization
import
Linearization
from
..sugar
import
from_local_data
from
numpy
import
log1p
class
Log1p
(
Operator
):
"""computes x -> log(1+x)
"""
def
__init__
(
self
,
dom
):
self
.
_domain
=
dom
self
.
_target
=
dom
def
apply
(
self
,
x
):
lin
=
isinstance
(
x
,
Linearization
)
xval
=
x
.
val
if
lin
else
x
xlval
=
xval
.
local_data
res
=
from_local_data
(
xval
.
domain
,
log1p
(
xlval
))
if
not
lin
:
return
res
jac
=
DiagonalOperator
(
1
/
(
1
+
xval
))
return
x
.
new
(
res
,
jac
@
x
.
jac
)
test/test_linearization.py
View file @
8b2f500c
...
...
@@ -54,7 +54,7 @@ def test_special_gradients():
@
pmp
(
'f'
,
[
'log'
,
'exp'
,
'sqrt'
,
'sin'
,
'cos'
,
'tan'
,
'sinc'
,
'sinh'
,
'cosh'
,
'tanh'
,
'absolute'
,
'one_over'
,
'sigmoid'
,
'log10'
'absolute'
,
'one_over'
,
'sigmoid'
,
'log10'
,
'log1p'
])
def
test_actual_gradients
(
f
):
dom
=
ift
.
UnstructuredDomain
((
1
,))
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment