Commit 603f40d8 authored by Philipp Haim's avatar Philipp Haim
Browse files

Merge remote-tracking branch 'origin/fix_from_global_data' into normalized_amplitudes_DomT

parents 47dbd48f de358ce5
Pipeline #63868 passed with stages
in 9 minutes and 44 seconds
......@@ -19,7 +19,6 @@ from .multi_field import MultiField
from .operators.operator import Operator
from .operators.adder import Adder
from .operators.log1p import Log1p
from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import DOFDistributor, PowerDistributor
from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter
......@@ -32,7 +32,7 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "transpose", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed",
"tanh", "conjugate", "sin", "cos", "tan", "log10",
"tanh", "conjugate", "sin", "cos", "tan", "log10", "log1p", "expm1",
"sinh", "cosh", "sinc", "absolute", "sign", "clip"]
......@@ -297,7 +297,8 @@ def _math_helper(x, function, out):
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate", "sin", "cos", "tan",
"sinh", "cosh", "sinc", "absolute", "sign", "log10"]:
"sinh", "cosh", "sinc", "absolute", "sign", "log10", "log1p",
def func(f):
def func2(x, out=None):
return _math_helper(x, f, out)
......@@ -396,6 +397,8 @@ def from_global_data(arr, sum_up=False, distaxis=0):
raise TypeError
if sum_up:
arr = np_allreduce_sum(arr)
if arr.ndim == 0:
distaxis = -1
if distaxis == -1:
return data_object(arr.shape, arr, distaxis)
lo, hi = _shareRange(arr.shape[distaxis], ntask, rank)
......@@ -22,7 +22,7 @@ from numpy import ndarray as data_object
from numpy import empty, empty_like, ones, zeros, full
from numpy import absolute, sign, clip, vdot
from numpy import sin, cos, sinh, cosh, tan, tanh
from numpy import exp, log, log10, sqrt, sinc
from numpy import exp, log, log10, sqrt, sinc, log1p, expm1
from .random import Random
......@@ -35,8 +35,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"redistribute", "default_distaxis", "is_numpy", "absmax", "norm",
"lock", "locked", "uniform_full", "to_global_data_rw",
"ensure_not_distributed", "ensure_default_distributed",
"clip", "sin", "cos", "tan", "sinh",
"cosh", "absolute", "sign", "sinc", "log10"]
"clip", "sin", "cos", "tan", "sinh", "cosh",
"absolute", "sign", "sinc", "log10", "log1p", "expm1"]
ntask = 1
rank = 0
......@@ -663,9 +663,8 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
return func2
setattr(Field, op, func(op))
for f in ["sqrt", "exp", "log", "log10", "tanh",
"sin", "cos", "tan", "cosh", "sinh",
"absolute", "sinc", "sign"]:
for f in ["sqrt", "exp", "log", "sin", "cos", "tan", "sinh", "cosh", "tanh",
"absolute", "sinc", "sign", "log10", "log1p", "expm1"]:
def func(f):
def func2(self):
return Field(self._domain, getattr(dobj, f)(self.val))
......@@ -335,6 +335,16 @@ class Linearization(object):
tmp2 = 1. / (self._val * np.log(10))
return, makeOp(tmp2)(self._jac))
def log1p(self):
tmp = self._val.log1p()
tmp2 = 1. / (1. + self._val)
return, makeOp(tmp2)(self.jac))
def expm1(self):
tmp = self._val.expm1()
tmp2 = self._val.exp()
return, makeOp(tmp2)(self.jac))
def sinh(self):
tmp = self._val.sinh()
tmp2 = self._val.cosh()
......@@ -338,7 +338,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
setattr(MultiField, op, func(op))
for f in ["sqrt", "exp", "log", "tanh"]:
for f in ["sqrt", "exp", "log", "log1p", "expm1", "tanh"]:
def func(f):
def func2(self):
fu = getattr(Field, f)
......@@ -28,16 +28,17 @@ class ContractionOperator(LinearOperator):
This Operator sums up a field with is defined on a :class:`DomainTuple`
to a :class:`DomainTuple` which contains the former as a subset.
to a :class:`DomainTuple` which is a subset of the former.
domain : Domain, tuple of Domain or DomainTuple
spaces : int or tuple of int
spaces : None, int or tuple of int
The elements of "domain" which are contracted.
If `None`, everything is contracted
weight : int, default=0
If nonzero, the fields defined on self.domain are weighted with the
specified power.
specified power along the submdomains which are contracted.
def __init__(self, domain, spaces, weight=0):
......@@ -250,20 +250,18 @@ class InverseGammaLikelihood(EnergyOperator):
class StudentTEnergy(EnergyOperator):
"""Computes likelihood energy of expected event frequency constrained by
event data.
"""Computes likelihood energy corresponding to Student's t-distribution.
.. math ::
E(f) = -\\log \\text{Bernoulli}(d|f)
= -d^\\dagger \\log f - (1-d)^\\dagger \\log(1-f),
E_\\theta(f) = -\\log \\text{StudentT}_\\theta(f)
= \\frac{\\theta + 1}{2} \\log(1 + \\frac{f^2}{\\theta}),
where f is a field defined on `d.domain` with the expected
frequencies of events.
where f is a field defined on `domain`.
d : Field
Data field with events (1) or non-events (0).
domain : `Domain` or `DomainTuple`
Domain of the operator
theta : Scalar
Degree of freedom parameter for the student t distribution
......@@ -271,12 +269,10 @@ class StudentTEnergy(EnergyOperator):
def __init__(self, domain, theta):
self._domain = DomainTuple.make(domain)
self._theta = theta
from .log1p import Log1p
self._l1p = Log1p(domain)
def apply(self, x):
v = ((self._theta+1)/2)*self._l1p(x**2/self._theta).sum()
v = ((self._theta+1)/2)*(x**2/self._theta).log1p().sum()
if not isinstance(x, Linearization):
return Field.scalar(v)
if not x.want_metric:
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <>.
# Copyright(C) 2013-2019 Max-Planck-Society
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..field import Field
from ..multi_field import MultiField
from .operator import Operator
from .diagonal_operator import DiagonalOperator
from ..linearization import Linearization
from ..sugar import from_local_data
from numpy import log1p
class Log1p(Operator):
"""computes x -> log(1+x)
def __init__(self, dom):
self._domain = dom
self._target = dom
def apply(self, x):
lin = isinstance(x, Linearization)
xval = x.val if lin else x
xlval = xval.local_data
res = from_local_data(xval.domain, log1p(xlval))
if not lin:
return res
jac = DiagonalOperator(1/(1+xval))
return, jac@x.jac)
......@@ -391,7 +391,7 @@ _current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "log10", "tanh", "sigmoid",
"conjugate", 'sin', 'cos', 'tan', 'sinh', 'cosh',
'absolute', 'one_over', 'sinc']:
'absolute', 'one_over', 'sinc', 'log1p', 'expm1']:
def func(f):
def func2(x):
from .linearization import Linearization
......@@ -54,7 +54,7 @@ def test_special_gradients():
@pmp('f', [
'log', 'exp', 'sqrt', 'sin', 'cos', 'tan', 'sinc', 'sinh', 'cosh', 'tanh',
'absolute', 'one_over', 'sigmoid', 'log10'
'absolute', 'one_over', 'sigmoid', 'log10', 'log1p', "expm1"
def test_actual_gradients(f):
dom = ift.UnstructuredDomain((1,))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment