Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
d37dc659
Commit
d37dc659
authored
Mar 06, 2020
by
Martin Reinecke
Browse files
Merge branch 'gig-energy' into 'NIFTy_6'
Gig energy See merge request
!411
parents
df116915
7a6052d7
Pipeline
#70299
passed with stages
in 15 minutes and 32 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/__init__.py
View file @
d37dc659
...
...
@@ -44,7 +44,7 @@ from .operators.value_inserter import ValueInserter
from
.operators.energy_operators
import
(
EnergyOperator
,
GaussianEnergy
,
PoissonianEnergy
,
InverseGammaLikelihood
,
BernoulliEnergy
,
StandardHamiltonian
,
AveragedEnergy
,
QuadraticFormOperator
,
Squared2NormOperator
,
StudentTEnergy
)
Squared2NormOperator
,
StudentTEnergy
,
VariableCovarianceGaussianEnergy
)
from
.operators.convolution_operators
import
FuncConvolutionOperator
from
.probing
import
probe_with_posterior_samples
,
probe_diagonal
,
\
...
...
nifty6/operators/energy_operators.py
View file @
d37dc659
...
...
@@ -19,6 +19,7 @@ import numpy as np
from
..
import
utilities
from
..domain_tuple
import
DomainTuple
from
..multi_domain
import
MultiDomain
from
..field
import
Field
from
..multi_field
import
MultiField
from
..linearization
import
Linearization
...
...
@@ -28,7 +29,7 @@ from .operator import Operator
from
.sampling_enabler
import
SamplingEnabler
from
.sandwich_operator
import
SandwichOperator
from
.scaling_operator
import
ScalingOperator
from
.simple_linear_operators
import
VdotOperator
from
.simple_linear_operators
import
VdotOperator
,
FieldAdapter
class
EnergyOperator
(
Operator
):
...
...
@@ -96,6 +97,47 @@ class QuadraticFormOperator(EnergyOperator):
return
Field
.
scalar
(
0.5
*
x
.
vdot
(
self
.
_op
(
x
)))
class
VariableCovarianceGaussianEnergy
(
EnergyOperator
):
"""Computes the negative log pdf of a Gaussian with unknown covariance.
The covariance is assumed to be diagonal.
.. math ::
E(s,D) = -
\\
log G(s, D) = 0.5 (s)^
\\
dagger D^{-1} (s) + 0.5 tr log(D),
an information energy for a Gaussian distribution with residual s and
diagonal covariance D.
The domain of this energy will be a MultiDomain with two keys,
the target will be the scalar domain.
Parameters
----------
domain : Domain, DomainTuple, tuple of Domain
domain of the residual and domain of the covariance diagonal.
residual : key
Residual key of the Gaussian.
inverse_covariance : key
Inverse covariance diagonal key of the Gaussian.
"""
def
__init__
(
self
,
domain
,
residual_key
,
inverse_covariance_key
):
self
.
_r
=
str
(
residual_key
)
self
.
_icov
=
str
(
inverse_covariance_key
)
dom
=
DomainTuple
.
make
(
domain
)
self
.
_domain
=
MultiDomain
.
make
({
self
.
_r
:
dom
,
self
.
_icov
:
dom
})
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
res0
=
x
[
self
.
_r
].
vdot
(
x
[
self
.
_r
]
*
x
[
self
.
_icov
]).
real
res1
=
x
[
self
.
_icov
].
log
().
sum
()
res
=
0.5
*
(
res0
-
res1
)
mf
=
{
self
.
_r
:
x
.
val
[
self
.
_icov
],
self
.
_icov
:
.
5
*
x
.
val
[
self
.
_icov
]
**
(
-
2
)}
metric
=
makeOp
(
MultiField
.
from_dict
(
mf
))
return
res
.
add_metric
(
SandwichOperator
.
make
(
x
.
jac
,
metric
))
class
GaussianEnergy
(
EnergyOperator
):
"""Computes a negative-log Gaussian.
...
...
test/test_energy_gradients.py
View file @
d37dc659
...
...
@@ -43,23 +43,33 @@ def field(request):
return
ift
.
MultiField
.
from_dict
({
's1'
:
s
})[
's1'
]
def
test_variablecovariancegaussian
(
field
):
dc
=
{
'a'
:
field
,
'b'
:
field
.
exp
()}
mf
=
ift
.
MultiField
.
from_dict
(
dc
)
energy
=
ift
.
VariableCovarianceGaussianEnergy
(
field
.
domain
,
'a'
,
'b'
)
ift
.
extra
.
check_jacobian_consistency
(
energy
,
mf
,
tol
=
1e-6
)
energy
(
ift
.
Linearization
.
make_var
(
mf
,
want_metric
=
True
)).
metric
.
draw_sample
()
def
test_gaussian
(
field
):
energy
=
ift
.
GaussianEnergy
(
domain
=
field
.
domain
)
ift
.
extra
.
check_jacobian_consistency
(
energy
,
field
)
@
pmp
(
'icov'
,
[
lambda
dom
:
ift
.
ScalingOperator
(
dom
,
1.
),
lambda
dom
:
ift
.
SandwichOperator
.
make
(
ift
.
GeometryRemover
(
dom
))])
lambda
dom
:
ift
.
SandwichOperator
.
make
(
ift
.
GeometryRemover
(
dom
))])
def
test_ScaledEnergy
(
field
,
icov
):
icov
=
icov
(
field
.
domain
)
energy
=
ift
.
GaussianEnergy
(
inverse_covariance
=
icov
)
ift
.
extra
.
check_jacobian_consistency
(
energy
.
scale
(
0.3
),
field
)
lin
=
ift
.
Linearization
.
make_var
(
field
,
want_metric
=
True
)
lin
=
ift
.
Linearization
.
make_var
(
field
,
want_metric
=
True
)
met1
=
energy
(
lin
).
metric
met2
=
energy
.
scale
(
0.3
)(
lin
).
metric
np
.
testing
.
assert_allclose
(
met1
(
field
).
val
,
met2
(
field
).
val
/
0.3
,
rtol
=
1e-12
)
np
.
testing
.
assert_allclose
(
met1
(
field
).
val
,
met2
(
field
).
val
/
0.3
,
rtol
=
1e-12
)
met2
.
draw_sample
()
def
test_studentt
(
field
):
energy
=
ift
.
StudentTEnergy
(
domain
=
field
.
domain
,
theta
=
.
5
)
ift
.
extra
.
check_jacobian_consistency
(
energy
,
field
,
tol
=
1e-6
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment