Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
668fc8c7
Commit
668fc8c7
authored
Mar 04, 2020
by
Reimar Leike
Browse files
added a Gaussian Energy with variabel sigma, has to be debugged
parent
df116915
Pipeline
#70181
failed with stages
in 19 minutes and 25 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/__init__.py
View file @
668fc8c7
...
...
@@ -44,7 +44,7 @@ from .operators.value_inserter import ValueInserter
from
.operators.energy_operators
import
(
EnergyOperator
,
GaussianEnergy
,
PoissonianEnergy
,
InverseGammaLikelihood
,
BernoulliEnergy
,
StandardHamiltonian
,
AveragedEnergy
,
QuadraticFormOperator
,
Squared2NormOperator
,
StudentTEnergy
)
Squared2NormOperator
,
StudentTEnergy
,
VariableCovarianceGaussianEnergy
)
from
.operators.convolution_operators
import
FuncConvolutionOperator
from
.probing
import
probe_with_posterior_samples
,
probe_diagonal
,
\
...
...
nifty6/operators/energy_operators.py
View file @
668fc8c7
...
...
@@ -19,6 +19,7 @@ import numpy as np
from
..
import
utilities
from
..domain_tuple
import
DomainTuple
from
..multi_domain
import
MultiDomain
from
..field
import
Field
from
..multi_field
import
MultiField
from
..linearization
import
Linearization
...
...
@@ -28,7 +29,7 @@ from .operator import Operator
from
.sampling_enabler
import
SamplingEnabler
from
.sandwich_operator
import
SandwichOperator
from
.scaling_operator
import
ScalingOperator
from
.simple_linear_operators
import
VdotOperator
from
.simple_linear_operators
import
VdotOperator
,
FieldAdapter
class
EnergyOperator
(
Operator
):
...
...
@@ -95,6 +96,63 @@ class QuadraticFormOperator(EnergyOperator):
return
x
.
new
(
val
,
jac
)
return
Field
.
scalar
(
0.5
*
x
.
vdot
(
self
.
_op
(
x
)))
class
VariableCovarianceGaussianEnergy
(
EnergyOperator
):
"""Computes a negative-log Gaussian with unknown covariance.
Represents up to constants in :math:`m`:
.. math ::
E(f) = -
\\
log G(s, D) = 0.5 (s)^
\\
dagger D^{-1} (s),
an information energy for a Gaussian distribution with residual s and
covariance D.
Parameters
----------
residual : key
residual of the Gaussian.
inverse_covariance : key
Inverse covariance of the Gaussian.
domain : Domain, DomainTuple, tuple of Domain
Operator domain. By default it is inferred from `mean` or
`covariance` if specified
"""
def
__init__
(
self
,
domain
,
residual
,
inverse_covariance
):
self
.
_residual
=
residual
self
.
_icov
=
inverse_covariance
self
.
_domain
=
MultiDomain
.
make
({
self
.
_residual
:
domain
,
self
.
_icov
:
domain
})
self
.
_singledom
=
domain
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
lin
=
isinstance
(
x
,
Linearization
)
xval
=
x
.
val
if
lin
else
x
res
=
.
5
*
xval
[
self
.
_residual
].
vdot
(
xval
[
self
.
_residual
]
*
xval
[
self
.
_icov
])
\
-
.
5
*
xval
[
self
.
_icov
].
log
().
sum
()
if
not
lin
:
return
res
FA_res
=
FieldAdapter
(
self
.
_singledom
,
self
.
_residual
)
FA_sig
=
FieldAdapter
(
self
.
_singledom
,
self
.
_icov
)
jac_res
=
xval
[
self
.
_residual
]
*
xval
[
self
.
_icov
]
jac_res
=
VdotOperator
(
jac_res
)(
FA_res
)
jac_sig
=
.
5
*
(
xval
[
self
.
_residual
].
absolute
()
**
2
)
jac_sig
=
VdotOperator
(
jac_sig
)(
FA_sig
)
jac_sig
=
jac_sig
-
VdotOperator
(
1.
/
xval
[
self
.
_residual
])(
FA_sig
)
jac
=
(
jac_sig
+
jac_res
)(
x
.
jac
)
res
=
x
.
new
(
res
,
jac
)
if
not
x
.
want_metric
:
return
res
mf
=
{
self
.
_residual
:
xval
[
self
.
_icov
],
self
.
_icov
:.
5
*
xval
[
self
.
_icov
]
**
(
-
2
)}
mf
=
MultiField
.
from_dict
(
mf
)
metric
=
makeOp
(
mf
)
metric
=
SandwichOperator
(
x
.
jac
,
metric
)
return
res
.
add_metric
(
metric
)
class
GaussianEnergy
(
EnergyOperator
):
"""Computes a negative-log Gaussian.
...
...
test/test_energy_gradients.py
View file @
668fc8c7
...
...
@@ -42,6 +42,12 @@ def field(request):
s
=
S
.
draw_sample
()
return
ift
.
MultiField
.
from_dict
({
's1'
:
s
})[
's1'
]
def
test_variablecovariancegaussian
(
field
):
dc
=
{
'a'
:
field
,
'b'
:
field
.
exp
()}
mf
=
ift
.
MultiField
.
from_dict
(
dc
)
energy
=
ift
.
VariableCovarianceGaussianEnergy
(
field
.
domain
,
residual
=
'a'
,
inverse_covariance
=
'b'
)
ift
.
extra
.
check_jacobian_consistency
(
energy
,
mf
)
def
test_gaussian
(
field
):
energy
=
ift
.
GaussianEnergy
(
domain
=
field
.
domain
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment