Commit 8b2f500c authored by Reimar H Leike's avatar Reimar H Leike
Browse files

build in log1p as a nonlinearity instead of as an operator

parent 66ccf5ea
Pipeline #63712 passed with stages
in 8 minutes and 12 seconds
...@@ -20,7 +20,6 @@ from .multi_field import MultiField ...@@ -20,7 +20,6 @@ from .multi_field import MultiField
from .operators.operator import Operator from .operators.operator import Operator
from .operators.adder import Adder from .operators.adder import Adder
from .operators.log1p import Log1p
from .operators.diagonal_operator import DiagonalOperator from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import DOFDistributor, PowerDistributor from .operators.distributors import DOFDistributor, PowerDistributor
from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter
......
...@@ -32,7 +32,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", ...@@ -32,7 +32,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm", "redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "transpose", "to_global_data_rw", "lock", "locked", "uniform_full", "transpose", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed", "ensure_not_distributed", "ensure_default_distributed",
"tanh", "conjugate", "sin", "cos", "tan", "log10", "tanh", "conjugate", "sin", "cos", "tan", "log10", "log1p",
"sinh", "cosh", "sinc", "absolute", "sign", "clip"] "sinh", "cosh", "sinc", "absolute", "sign", "clip"]
_comm = MPI.COMM_WORLD _comm = MPI.COMM_WORLD
...@@ -297,7 +297,7 @@ def _math_helper(x, function, out): ...@@ -297,7 +297,7 @@ def _math_helper(x, function, out):
_current_module = sys.modules[__name__] _current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate", "sin", "cos", "tan", for f in ["sqrt", "exp", "log", "tanh", "conjugate", "sin", "cos", "tan",
"sinh", "cosh", "sinc", "absolute", "sign", "log10"]: "sinh", "cosh", "sinc", "absolute", "sign", "log10", "log1p"]:
def func(f): def func(f):
def func2(x, out=None): def func2(x, out=None):
return _math_helper(x, f, out) return _math_helper(x, f, out)
......
...@@ -22,7 +22,7 @@ from numpy import ndarray as data_object ...@@ -22,7 +22,7 @@ from numpy import ndarray as data_object
from numpy import empty, empty_like, ones, zeros, full from numpy import empty, empty_like, ones, zeros, full
from numpy import absolute, sign, clip, vdot from numpy import absolute, sign, clip, vdot
from numpy import sin, cos, sinh, cosh, tan, tanh from numpy import sin, cos, sinh, cosh, tan, tanh
from numpy import exp, log, log10, sqrt, sinc from numpy import exp, log, log10, sqrt, sinc, log1p
from .random import Random from .random import Random
...@@ -36,7 +36,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full", ...@@ -36,7 +36,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"lock", "locked", "uniform_full", "to_global_data_rw", "lock", "locked", "uniform_full", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed", "ensure_not_distributed", "ensure_default_distributed",
"clip", "sin", "cos", "tan", "sinh", "clip", "sin", "cos", "tan", "sinh",
"cosh", "absolute", "sign", "sinc", "log10"] "cosh", "absolute", "sign", "sinc", "log10", "log1p"]
ntask = 1 ntask = 1
rank = 0 rank = 0
......
...@@ -663,7 +663,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", ...@@ -663,7 +663,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
return func2 return func2
setattr(Field, op, func(op)) setattr(Field, op, func(op))
for f in ["sqrt", "exp", "log", "log10", "tanh", for f in ["sqrt", "exp", "log", "log10", "log1p", "tanh",
"sin", "cos", "tan", "cosh", "sinh", "sin", "cos", "tan", "cosh", "sinh",
"absolute", "sinc", "sign"]: "absolute", "sinc", "sign"]:
def func(f): def func(f):
......
...@@ -335,6 +335,12 @@ class Linearization(object): ...@@ -335,6 +335,12 @@ class Linearization(object):
tmp2 = 1. / (self._val * np.log(10)) tmp2 = 1. / (self._val * np.log(10))
return self.new(tmp, makeOp(tmp2)(self._jac)) return self.new(tmp, makeOp(tmp2)(self._jac))
def log1p(self):
xval = self.val
res = xval.log1p()
jac = makeOp(1. / (1. + xval))
return self.new(res, jac @ self.jac)
def sinh(self): def sinh(self):
tmp = self._val.sinh() tmp = self._val.sinh()
tmp2 = self._val.cosh() tmp2 = self._val.cosh()
......
...@@ -338,7 +338,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", ...@@ -338,7 +338,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
setattr(MultiField, op, func(op)) setattr(MultiField, op, func(op))
for f in ["sqrt", "exp", "log", "tanh"]: for f in ["sqrt", "exp", "log", "log1p", "tanh"]:
def func(f): def func(f):
def func2(self): def func2(self):
fu = getattr(Field, f) fu = getattr(Field, f)
......
...@@ -269,12 +269,10 @@ class StudentTEnergy(EnergyOperator): ...@@ -269,12 +269,10 @@ class StudentTEnergy(EnergyOperator):
def __init__(self, domain, theta): def __init__(self, domain, theta):
self._domain = DomainTuple.make(domain) self._domain = DomainTuple.make(domain)
self._theta = theta self._theta = theta
from .log1p import Log1p
self._l1p = Log1p(domain)
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
v = ((self._theta+1)/2)*self._l1p(x**2/self._theta).sum() v = ((self._theta+1)/2)*(x**2/self._theta).log1p().sum()
if not isinstance(x, Linearization): if not isinstance(x, Linearization):
return Field.scalar(v) return Field.scalar(v)
if not x.want_metric: if not x.want_metric:
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..field import Field
from ..multi_field import MultiField
from .operator import Operator
from .diagonal_operator import DiagonalOperator
from ..linearization import Linearization
from ..sugar import from_local_data
from numpy import log1p
class Log1p(Operator):
"""computes x -> log(1+x)
"""
def __init__(self, dom):
self._domain = dom
self._target = dom
def apply(self, x):
lin = isinstance(x, Linearization)
xval = x.val if lin else x
xlval = xval.local_data
res = from_local_data(xval.domain, log1p(xlval))
if not lin:
return res
jac = DiagonalOperator(1/(1+xval))
return x.new(res, jac@x.jac)
...@@ -54,7 +54,7 @@ def test_special_gradients(): ...@@ -54,7 +54,7 @@ def test_special_gradients():
@pmp('f', [ @pmp('f', [
'log', 'exp', 'sqrt', 'sin', 'cos', 'tan', 'sinc', 'sinh', 'cosh', 'tanh', 'log', 'exp', 'sqrt', 'sin', 'cos', 'tan', 'sinc', 'sinh', 'cosh', 'tanh',
'absolute', 'one_over', 'sigmoid', 'log10' 'absolute', 'one_over', 'sigmoid', 'log10', 'log1p'
]) ])
def test_actual_gradients(f): def test_actual_gradients(f):
dom = ift.UnstructuredDomain((1,)) dom = ift.UnstructuredDomain((1,))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment