Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
4231c07b
Commit
4231c07b
authored
Oct 23, 2019
by
Reimar H Leike
Browse files
added a student-t energy, a Log(1+x) nonlinearity and a test for the energy
parent
156c9d79
Pipeline
#62423
passed with stages
in 7 minutes and 5 seconds
Changes
4
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty5/__init__.py
View file @
4231c07b
...
...
@@ -20,6 +20,7 @@ from .multi_field import MultiField
from
.operators.operator
import
Operator
from
.operators.adder
import
Adder
from
.operators.log1p
import
Log1p
from
.operators.diagonal_operator
import
DiagonalOperator
from
.operators.distributors
import
DOFDistributor
,
PowerDistributor
from
.operators.domain_tuple_field_inserter
import
DomainTupleFieldInserter
...
...
@@ -51,7 +52,7 @@ from .operators.value_inserter import ValueInserter
from
.operators.energy_operators
import
(
EnergyOperator
,
GaussianEnergy
,
PoissonianEnergy
,
InverseGammaLikelihood
,
BernoulliEnergy
,
StandardHamiltonian
,
AveragedEnergy
,
QuadraticFormOperator
,
Squared2NormOperator
)
Squared2NormOperator
,
StudentTEnergy
)
from
.operators.convolution_operators
import
FuncConvolutionOperator
from
.probing
import
probe_with_posterior_samples
,
probe_diagonal
,
\
...
...
nifty5/operators/energy_operators.py
View file @
4231c07b
...
...
@@ -27,6 +27,7 @@ from .linear_operator import LinearOperator
from
.operator
import
Operator
from
.sampling_enabler
import
SamplingEnabler
from
.sandwich_operator
import
SandwichOperator
from
.scaling_operator
import
ScalingOperator
from
.simple_linear_operators
import
VdotOperator
...
...
@@ -64,7 +65,6 @@ class Squared2NormOperator(EnergyOperator):
return
x
.
new
(
val
,
jac
)
return
Field
.
scalar
(
x
.
vdot
(
x
))
class
QuadraticFormOperator
(
EnergyOperator
):
"""Computes the L2-norm of a Field or MultiField with respect to a
specific kernel given by `endo`.
...
...
@@ -248,6 +248,43 @@ class InverseGammaLikelihood(EnergyOperator):
return
res
.
add_metric
(
metric
)
class
StudentTEnergy
(
EnergyOperator
):
"""Computes likelihood energy of expected event frequency constrained by
event data.
.. math ::
E(f) = -
\\
log
\\
text{Bernoulli}(d|f)
= -d^
\\
dagger
\\
log f - (1-d)^
\\
dagger
\\
log(1-f),
where f is a field defined on `d.domain` with the expected
frequencies of events.
Parameters
----------
d : Field
Data field with events (1) or non-events (0).
theta : Scalar
Degree of freedom parameter for the student t distribution
"""
def
__init__
(
self
,
domain
,
theta
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_theta
=
theta
from
.log1p
import
Log1p
self
.
_l1p
=
Log1p
(
domain
)
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
v
=
((
self
.
_theta
+
1
)
/
2
)
*
self
.
_l1p
(
x
**
2
/
self
.
_theta
).
sum
()
if
not
isinstance
(
x
,
Linearization
):
return
Field
.
scalar
(
v
)
if
not
x
.
want_metric
:
return
v
met
=
ScalingOperator
(
self
.
domain
,
(
self
.
_theta
+
1
)
/
(
self
.
_theta
+
3
))
met
=
SandwichOperator
.
make
(
x
.
jac
,
met
)
return
v
.
add_metric
(
met
)
class
BernoulliEnergy
(
EnergyOperator
):
"""Computes likelihood energy of expected event frequency constrained by
event data.
...
...
nifty5/operators/log1p.py
0 → 100644
View file @
4231c07b
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from
..field
import
Field
from
..multi_field
import
MultiField
from
.operator
import
Operator
from
.diagonal_operator
import
DiagonalOperator
from
..linearization
import
Linearization
from
..sugar
import
from_local_data
from
numpy
import
log1p
class
Log1p
(
Operator
):
"""computes x -> log(1+x)
"""
def
__init__
(
self
,
dom
):
self
.
_domain
=
dom
self
.
_target
=
dom
def
apply
(
self
,
x
):
lin
=
isinstance
(
x
,
Linearization
)
xval
=
x
.
val
if
lin
else
x
xlval
=
xval
.
local_data
res
=
from_local_data
(
x
.
domain
,
log1p
(
xlval
))
if
not
lin
:
return
res
jac
=
DiagonalOperator
(
1
/
(
1
+
xval
))
return
x
.
new
(
res
,
jac
@
x
.
jac
)
test/test_energy_gradients.py
View file @
4231c07b
...
...
@@ -46,6 +46,9 @@ def test_gaussian(field):
energy
=
ift
.
GaussianEnergy
(
domain
=
field
.
domain
)
ift
.
extra
.
check_jacobian_consistency
(
energy
,
field
)
def
test_studentt
(
field
):
energy
=
ift
.
StudentTEnergy
(
domain
=
field
.
domain
,
theta
=
.
5
)
ift
.
extra
.
check_jacobian_consistency
(
energy
,
field
,
tol
=
1e-6
)
def
test_inverse_gamma
(
field
):
field
=
field
.
exp
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment