Commit 8b2f500c authored by Reimar H Leike's avatar Reimar H Leike
Browse files

build in log1p as a nonlinearity instead of as an operator

parent 66ccf5ea
Pipeline #63712 passed with stages
in 8 minutes and 12 seconds
......@@ -20,7 +20,6 @@ from .multi_field import MultiField
from .operators.operator import Operator
from .operators.adder import Adder
from .operators.log1p import Log1p
from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import DOFDistributor, PowerDistributor
from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter
......
......@@ -32,7 +32,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "transpose", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed",
"tanh", "conjugate", "sin", "cos", "tan", "log10",
"tanh", "conjugate", "sin", "cos", "tan", "log10", "log1p",
"sinh", "cosh", "sinc", "absolute", "sign", "clip"]
_comm = MPI.COMM_WORLD
......@@ -297,7 +297,7 @@ def _math_helper(x, function, out):
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate", "sin", "cos", "tan",
"sinh", "cosh", "sinc", "absolute", "sign", "log10"]:
"sinh", "cosh", "sinc", "absolute", "sign", "log10", "log1p"]:
def func(f):
def func2(x, out=None):
return _math_helper(x, f, out)
......
......@@ -22,7 +22,7 @@ from numpy import ndarray as data_object
from numpy import empty, empty_like, ones, zeros, full
from numpy import absolute, sign, clip, vdot
from numpy import sin, cos, sinh, cosh, tan, tanh
from numpy import exp, log, log10, sqrt, sinc
from numpy import exp, log, log10, sqrt, sinc, log1p
from .random import Random
......@@ -36,7 +36,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"lock", "locked", "uniform_full", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed",
"clip", "sin", "cos", "tan", "sinh",
"cosh", "absolute", "sign", "sinc", "log10"]
"cosh", "absolute", "sign", "sinc", "log10", "log1p"]
ntask = 1
rank = 0
......
......@@ -663,7 +663,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
return func2
setattr(Field, op, func(op))
for f in ["sqrt", "exp", "log", "log10", "tanh",
for f in ["sqrt", "exp", "log", "log10", "log1p", "tanh",
"sin", "cos", "tan", "cosh", "sinh",
"absolute", "sinc", "sign"]:
def func(f):
......
......@@ -335,6 +335,12 @@ class Linearization(object):
tmp2 = 1. / (self._val * np.log(10))
return self.new(tmp, makeOp(tmp2)(self._jac))
def log1p(self):
xval = self.val
res = xval.log1p()
jac = makeOp(1. / (1. + xval))
return self.new(res, jac @ self.jac)
def sinh(self):
tmp = self._val.sinh()
tmp2 = self._val.cosh()
......
......@@ -338,7 +338,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
setattr(MultiField, op, func(op))
for f in ["sqrt", "exp", "log", "tanh"]:
for f in ["sqrt", "exp", "log", "log1p", "tanh"]:
def func(f):
def func2(self):
fu = getattr(Field, f)
......
......@@ -269,12 +269,10 @@ class StudentTEnergy(EnergyOperator):
def __init__(self, domain, theta):
self._domain = DomainTuple.make(domain)
self._theta = theta
from .log1p import Log1p
self._l1p = Log1p(domain)
def apply(self, x):
self._check_input(x)
v = ((self._theta+1)/2)*self._l1p(x**2/self._theta).sum()
v = ((self._theta+1)/2)*(x**2/self._theta).log1p().sum()
if not isinstance(x, Linearization):
return Field.scalar(v)
if not x.want_metric:
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..field import Field
from ..multi_field import MultiField
from .operator import Operator
from .diagonal_operator import DiagonalOperator
from ..linearization import Linearization
from ..sugar import from_local_data
from numpy import log1p
class Log1p(Operator):
"""computes x -> log(1+x)
"""
def __init__(self, dom):
self._domain = dom
self._target = dom
def apply(self, x):
lin = isinstance(x, Linearization)
xval = x.val if lin else x
xlval = xval.local_data
res = from_local_data(xval.domain, log1p(xlval))
if not lin:
return res
jac = DiagonalOperator(1/(1+xval))
return x.new(res, jac@x.jac)
......@@ -54,7 +54,7 @@ def test_special_gradients():
@pmp('f', [
'log', 'exp', 'sqrt', 'sin', 'cos', 'tan', 'sinc', 'sinh', 'cosh', 'tanh',
'absolute', 'one_over', 'sigmoid', 'log10'
'absolute', 'one_over', 'sigmoid', 'log10', 'log1p'
])
def test_actual_gradients(f):
dom = ift.UnstructuredDomain((1,))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment