Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
21f0128d
Commit
21f0128d
authored
May 28, 2021
by
Philipp Frank
Committed by
Philipp Arras
May 28, 2021
Browse files
Implement Geometric KL
parent
0acc9ff7
Changes
16
Expand all
Hide whitespace changes
Inline
Side-by-side
ChangeLog.md
View file @
21f0128d
...
@@ -64,13 +64,29 @@ The implementation tests for nonlinear operators are now available in
...
@@ -64,13 +64,29 @@ The implementation tests for nonlinear operators are now available in
`ift.extra.check_operator()`
and for linear operators
`ift.extra.check_operator()`
and for linear operators
`ift.extra.check_linear_operator()`
.
`ift.extra.check_linear_operator()`
.
MetricGaussianKL interface
MetricGaussianKL interface
--------------------------
--------------------------
Users do not instantiate
`MetricGaussianKL`
by its constructor anymore. Rather
`mirror_samples`
is not set by default anymore.
`MetricGaussianKL.make()`
shall be used. Additionally,
`mirror_samples`
is not
set by default anymore.
GeoMetricKL
-----------
A new posterior approximation scheme, called geometric Variational Inference
(geoVI) was introduced.
`GeoMetricKL`
is analogous to
`MetricGaussianKL`
with
the exception that geoVI samples are used instead of MGVI samples. For further
details see (
<https://arxiv.org/abs/2105.10470>
).
LikelihoodOperator
------------------
A new subclass of
`EnergyOperator`
was introduced and all
`EnergyOperator`
s
that are likelihoods are now
`LikelihoodOperator`
s. A
`LikelihoodOperator`
has to implement the function
`get_transformation`
, which returns a
coordinate transformation in which the Fisher metric of the likelihood becomes
the identity matrix.
Changes since NIFTy 5
Changes since NIFTy 5
...
...
demos/getting_started_3.py
View file @
21f0128d
...
@@ -123,7 +123,7 @@ def main():
...
@@ -123,7 +123,7 @@ def main():
# Draw new samples to approximate the KL five times
# Draw new samples to approximate the KL five times
for
i
in
range
(
5
):
for
i
in
range
(
5
):
# Draw new samples and minimize KL
# Draw new samples and minimize KL
KL
=
ift
.
MetricGaussianKL
.
make
(
mean
,
H
,
N_samples
,
True
)
KL
=
ift
.
MetricGaussianKL
(
mean
,
H
,
N_samples
,
True
)
KL
,
convergence
=
minimizer
(
KL
)
KL
,
convergence
=
minimizer
(
KL
)
mean
=
KL
.
position
mean
=
KL
.
position
ift
.
extra
.
minisanity
(
data
,
lambda
x
:
N
.
inverse
,
signal_response
,
ift
.
extra
.
minisanity
(
data
,
lambda
x
:
N
.
inverse
,
signal_response
,
...
...
demos/getting_started_5_mf.py
View file @
21f0128d
...
@@ -134,7 +134,7 @@ def main():
...
@@ -134,7 +134,7 @@ def main():
for
i
in
range
(
5
):
for
i
in
range
(
5
):
# Draw new samples and minimize KL
# Draw new samples and minimize KL
KL
=
ift
.
MetricGaussianKL
.
make
(
mean
,
H
,
N_samples
,
True
)
KL
=
ift
.
MetricGaussianKL
(
mean
,
H
,
N_samples
,
True
)
KL
,
convergence
=
minimizer
(
KL
)
KL
,
convergence
=
minimizer
(
KL
)
mean
=
KL
.
position
mean
=
KL
.
position
...
...
demos/getting_started_density.py
View file @
21f0128d
...
@@ -94,7 +94,7 @@ if __name__ == "__main__":
...
@@ -94,7 +94,7 @@ if __name__ == "__main__":
for
i
in
range
(
5
):
for
i
in
range
(
5
):
# Draw new samples and minimize KL
# Draw new samples and minimize KL
kl
=
ift
.
MetricGaussianKL
.
make
(
mean
,
ham
,
n_samples
,
True
)
kl
=
ift
.
MetricGaussianKL
(
mean
,
ham
,
n_samples
,
True
)
kl
,
convergence
=
minimizer
(
kl
)
kl
,
convergence
=
minimizer
(
kl
)
mean
=
kl
.
position
mean
=
kl
.
position
...
...
demos/mgvi_visualized.py
View file @
21f0128d
...
@@ -36,6 +36,8 @@ import nifty7 as ift
...
@@ -36,6 +36,8 @@ import nifty7 as ift
def
main
():
def
main
():
use_geo
=
False
name
=
'GEO'
if
use_geo
else
'MGVI'
dom
=
ift
.
UnstructuredDomain
(
1
)
dom
=
ift
.
UnstructuredDomain
(
1
)
scale
=
10
scale
=
10
...
@@ -91,12 +93,16 @@ def main():
...
@@ -91,12 +93,16 @@ def main():
plt
.
figure
(
figsize
=
[
12
,
8
])
plt
.
figure
(
figsize
=
[
12
,
8
])
for
ii
in
range
(
15
):
for
ii
in
range
(
15
):
if
ii
%
3
==
0
:
if
ii
%
3
==
0
:
mgkl
=
ift
.
MetricGaussianKL
.
make
(
pos
,
ham
,
40
,
False
)
if
use_geo
:
mini_samp
=
ift
.
NewtonCG
(
ift
.
GradientNormController
(
iteration_limit
=
5
))
mgkl
=
ift
.
GeoMetricKL
(
pos
,
ham
,
100
,
mini_samp
,
False
)
else
:
mgkl
=
ift
.
MetricGaussianKL
(
pos
,
ham
,
100
,
False
)
plt
.
cla
()
plt
.
cla
()
plt
.
imshow
(
z
.
T
,
origin
=
'lower'
,
norm
=
LogNorm
(),
vmin
=
1e-3
,
plt
.
imshow
(
z
.
T
,
origin
=
'lower'
,
norm
=
LogNorm
(
vmin
=
1e-3
,
vmax
=
np
.
max
(
z
)),
vmax
=
np
.
max
(
z
),
cmap
=
'gist_earth_r'
,
cmap
=
'gist_earth_r'
,
extent
=
x_limits_scaled
+
y_limits
)
extent
=
x_limits_scaled
+
y_limits
)
if
ii
==
0
:
if
ii
==
0
:
cbar
=
plt
.
colorbar
()
cbar
=
plt
.
colorbar
()
cbar
.
ax
.
set_ylabel
(
'pdf'
)
cbar
.
ax
.
set_ylabel
(
'pdf'
)
...
@@ -105,12 +111,14 @@ def main():
...
@@ -105,12 +111,14 @@ def main():
samp
=
(
samp
+
pos
).
val
samp
=
(
samp
+
pos
).
val
xs
.
append
(
samp
[
'a'
])
xs
.
append
(
samp
[
'a'
])
ys
.
append
(
samp
[
'b'
])
ys
.
append
(
samp
[
'b'
])
plt
.
scatter
(
np
.
array
(
xs
)
*
scale
,
np
.
array
(
ys
),
label
=
'MGVI samples'
)
plt
.
scatter
(
pos
.
val
[
'a'
]
*
scale
,
pos
.
val
[
'b'
],
label
=
'MGVI latent mean'
)
plt
.
scatter
(
np
.
array
(
map_xs
)
*
scale
,
np
.
array
(
map_ys
),
plt
.
scatter
(
np
.
array
(
map_xs
)
*
scale
,
np
.
array
(
map_ys
),
label
=
'Laplace samples'
)
label
=
'Laplace samples'
)
plt
.
scatter
(
np
.
array
(
xs
)
*
scale
,
np
.
array
(
ys
),
label
=
name
+
' samples'
)
plt
.
scatter
(
pos
.
val
[
'a'
]
*
scale
,
pos
.
val
[
'b'
],
label
=
name
+
' latent mean'
)
plt
.
scatter
(
MAP
.
position
.
val
[
'a'
]
*
scale
,
MAP
.
position
.
val
[
'b'
],
plt
.
scatter
(
MAP
.
position
.
val
[
'a'
]
*
scale
,
MAP
.
position
.
val
[
'b'
],
label
=
'Maximum a posterior solution'
)
label
=
'Maximum a posterior solution'
)
plt
.
xlim
(
x_limits_scaled
)
plt
.
ylim
(
y_limits
)
plt
.
legend
()
plt
.
legend
()
plt
.
draw
()
plt
.
draw
()
plt
.
pause
(
1.0
)
plt
.
pause
(
1.0
)
...
...
src/__init__.py
View file @
21f0128d
...
@@ -71,7 +71,7 @@ from .minimization.scipy_minimizer import L_BFGS_B
...
@@ -71,7 +71,7 @@ from .minimization.scipy_minimizer import L_BFGS_B
from
.minimization.energy
import
Energy
from
.minimization.energy
import
Energy
from
.minimization.quadratic_energy
import
QuadraticEnergy
from
.minimization.quadratic_energy
import
QuadraticEnergy
from
.minimization.energy_adapter
import
EnergyAdapter
from
.minimization.energy_adapter
import
EnergyAdapter
from
.minimization.
metric_gaussian_kl
import
MetricGaussianKL
from
.minimization.
kl_energies
import
MetricGaussianKL
,
GeoMetricKL
from
.sugar
import
*
from
.sugar
import
*
...
...
src/extra.py
View file @
21f0128d
...
@@ -526,3 +526,22 @@ def _tableentries(redchisq, scmean, ndof, keylen):
...
@@ -526,3 +526,22 @@ def _tableentries(redchisq, scmean, ndof, keylen):
out
+=
f
"
{
ndof
[
kk
]:
>
11
}
"
out
+=
f
"
{
ndof
[
kk
]:
>
11
}
"
out
+=
"
\n
"
out
+=
"
\n
"
return
out
[:
-
1
]
return
out
[:
-
1
]
class
_KeyModifier
(
LinearOperator
):
def
__init__
(
self
,
domain
,
pre
):
if
not
isinstance
(
domain
,
MultiDomain
):
raise
ValueError
from
.sugar
import
makeDomain
self
.
_domain
=
makeDomain
(
domain
)
self
.
_pre
=
str
(
pre
)
target
=
{
self
.
_pre
+
k
:
domain
[
k
]
for
k
in
domain
.
keys
()}
self
.
_target
=
makeDomain
(
MultiDomain
.
make
(
target
))
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
res
=
{
self
.
_pre
+
k
:
x
[
k
]
for
k
in
self
.
_domain
.
keys
()}
else
:
res
=
{
k
:
x
[
self
.
_pre
+
k
]
for
k
in
self
.
_domain
.
keys
()}
return
MultiField
.
from_dict
(
res
,
domain
=
self
.
_tgt
(
mode
))
src/minimization/
metric_gaussian_kl
.py
→
src/minimization/
kl_energies
.py
View file @
21f0128d
This diff is collapsed.
Click to expand it.
src/operators/energy_operators.py
View file @
21f0128d
...
@@ -20,15 +20,18 @@ import numpy as np
...
@@ -20,15 +20,18 @@ import numpy as np
from
..
import
utilities
from
..
import
utilities
from
..domain_tuple
import
DomainTuple
from
..domain_tuple
import
DomainTuple
from
..field
import
Field
from
..field
import
Field
from
..linearization
import
Linearization
from
..multi_domain
import
MultiDomain
from
..multi_domain
import
MultiDomain
from
..multi_field
import
MultiField
from
..multi_field
import
MultiField
from
..sugar
import
makeDomain
,
makeOp
from
..sugar
import
makeDomain
,
makeOp
from
..utilities
import
myassert
from
..utilities
import
myassert
from
.linear_operator
import
LinearOperator
from
.linear_operator
import
LinearOperator
from
.operator
import
Operator
from
.operator
import
Operator
from
.adder
import
Adder
from
.sampling_enabler
import
SamplingDtypeSetter
,
SamplingEnabler
from
.sampling_enabler
import
SamplingDtypeSetter
,
SamplingEnabler
from
.scaling_operator
import
ScalingOperator
from
.scaling_operator
import
ScalingOperator
from
.simple_linear_operators
import
VdotOperator
from
.sandwich_operator
import
SandwichOperator
from
.simple_linear_operators
import
VdotOperator
,
FieldAdapter
def
_check_sampling_dtype
(
domain
,
dtypes
):
def
_check_sampling_dtype
(
domain
,
dtypes
):
...
@@ -78,6 +81,24 @@ class EnergyOperator(Operator):
...
@@ -78,6 +81,24 @@ class EnergyOperator(Operator):
_target
=
DomainTuple
.
scalar_domain
()
_target
=
DomainTuple
.
scalar_domain
()
class
LikelihoodOperator
(
EnergyOperator
):
"""`EnergyOperator` representing a likelihood. The input to the Operator
are the parameters of the likelihood. Unlike a general `EnergyOperator`,
the metric of a `LikelihoodOperator` is the Fisher information metric of
the likelihood.
"""
def
get_metric_at
(
self
,
x
):
"""Computes the Fisher information metric for a `LikelihoodOperator`
at `x` using the Jacobian of the coordinate transformation given by
`get_transformation`.
"""
dtp
,
f
=
self
.
get_transformation
()
ch
=
ScalingOperator
(
f
.
target
,
1.
)
if
dtp
is
not
None
:
ch
=
SamplingDtypeSetter
(
ch
,
dtp
)
return
SandwichOperator
.
make
(
f
(
Linearization
.
make_var
(
x
)).
jac
,
ch
)
class
Squared2NormOperator
(
EnergyOperator
):
class
Squared2NormOperator
(
EnergyOperator
):
"""Computes the square of the L2-norm of the output of an operator.
"""Computes the square of the L2-norm of the output of an operator.
...
@@ -126,7 +147,7 @@ class QuadraticFormOperator(EnergyOperator):
...
@@ -126,7 +147,7 @@ class QuadraticFormOperator(EnergyOperator):
return
x
.
new
(
res
,
VdotOperator
(
self
.
_op
(
x
.
val
)))
return
x
.
new
(
res
,
VdotOperator
(
self
.
_op
(
x
.
val
)))
class
VariableCovarianceGaussianEnergy
(
Energy
Operator
):
class
VariableCovarianceGaussianEnergy
(
Likelihood
Operator
):
"""Computes the negative log pdf of a Gaussian with unknown covariance.
"""Computes the negative log pdf of a Gaussian with unknown covariance.
The covariance is assumed to be diagonal.
The covariance is assumed to be diagonal.
...
@@ -152,9 +173,16 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
...
@@ -152,9 +173,16 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
sampling_dtype : np.dtype
sampling_dtype : np.dtype
Data type of the samples. Usually either 'np.float*' or 'np.complex*'
Data type of the samples. Usually either 'np.float*' or 'np.complex*'
use_full_fisher: boolean
Whether or not the proper Fisher information metric should be used as
a `metric`. If False the same approximation used in
`get_transformation` is used instead.
Default is True
"""
"""
def
__init__
(
self
,
domain
,
residual_key
,
inverse_covariance_key
,
sampling_dtype
):
def
__init__
(
self
,
domain
,
residual_key
,
inverse_covariance_key
,
sampling_dtype
,
use_full_fisher
=
True
):
self
.
_kr
=
str
(
residual_key
)
self
.
_kr
=
str
(
residual_key
)
self
.
_ki
=
str
(
inverse_covariance_key
)
self
.
_ki
=
str
(
inverse_covariance_key
)
dom
=
DomainTuple
.
make
(
domain
)
dom
=
DomainTuple
.
make
(
domain
)
...
@@ -162,6 +190,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
...
@@ -162,6 +190,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
self
.
_dt
=
{
self
.
_kr
:
sampling_dtype
,
self
.
_ki
:
np
.
float64
}
self
.
_dt
=
{
self
.
_kr
:
sampling_dtype
,
self
.
_ki
:
np
.
float64
}
_check_sampling_dtype
(
self
.
_domain
,
self
.
_dt
)
_check_sampling_dtype
(
self
.
_domain
,
self
.
_dt
)
self
.
_cplx
=
_iscomplex
(
sampling_dtype
)
self
.
_cplx
=
_iscomplex
(
sampling_dtype
)
self
.
_use_fisher
=
use_full_fisher
def
apply
(
self
,
x
):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
self
.
_check_input
(
x
)
...
@@ -172,10 +201,14 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
...
@@ -172,10 +201,14 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
res
=
0.5
*
(
r
.
vdot
(
r
*
i
)
-
i
.
ptw
(
"log"
).
sum
())
res
=
0.5
*
(
r
.
vdot
(
r
*
i
)
-
i
.
ptw
(
"log"
).
sum
())
if
not
x
.
want_metric
:
if
not
x
.
want_metric
:
return
res
return
res
met
=
1.
if
self
.
_cplx
else
.
5
if
self
.
_use_fisher
:
met
=
MultiField
.
from_dict
({
self
.
_kr
:
i
.
val
,
self
.
_ki
:
met
*
i
.
val
**
(
-
2
)},
met
=
1.
if
self
.
_cplx
else
0.5
domain
=
self
.
_domain
)
met
=
MultiField
.
from_dict
({
self
.
_kr
:
i
.
val
,
self
.
_ki
:
met
*
i
.
val
**
(
-
2
)},
return
res
.
add_metric
(
SamplingDtypeSetter
(
makeOp
(
met
),
self
.
_dt
))
domain
=
self
.
_domain
)
met
=
SamplingDtypeSetter
(
makeOp
(
met
),
self
.
_dt
)
else
:
met
=
self
.
get_metric_at
(
x
.
val
)
return
res
.
add_metric
(
met
)
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
.simplify_for_const
import
ConstantEnergyOperator
from
.simplify_for_const
import
ConstantEnergyOperator
...
@@ -190,20 +223,30 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
...
@@ -190,20 +223,30 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
res
=
GaussianEnergy
(
inverse_covariance
=
makeOp
(
cst
),
res
=
GaussianEnergy
(
inverse_covariance
=
makeOp
(
cst
),
sampling_dtype
=
dt
).
ducktape
(
self
.
_kr
)
sampling_dtype
=
dt
).
ducktape
(
self
.
_kr
)
trlog
=
cst
.
log
().
sum
().
val_rw
()
trlog
=
cst
.
log
().
sum
().
val_rw
()
if
not
_iscomplex
(
dt
)
:
if
not
self
.
_cplx
:
trlog
/=
2
trlog
/=
2
res
=
res
+
ConstantEnergyOperator
(
-
trlog
)
res
=
res
+
ConstantEnergyOperator
(
-
trlog
)
res
=
res
+
ConstantEnergyOperator
(
0.
)
res
=
res
+
ConstantEnergyOperator
(
0.
)
myassert
(
res
.
target
is
self
.
target
)
myassert
(
res
.
target
is
self
.
target
)
return
None
,
res
return
None
,
res
def
get_transformation
(
self
):
"""Note that for the metric of a `VariableCovarianceGaussianEnergy` no
global transformation to Euclidean space exists. A local approximation
ivoking the resudual is used instead.
"""
r
=
FieldAdapter
(
self
.
_domain
[
self
.
_kr
],
self
.
_kr
)
ivar
=
FieldAdapter
(
self
.
_domain
[
self
.
_kr
],
self
.
_ki
)
sc
=
1.
if
self
.
_cplx
else
0.5
return
self
.
_dt
,
r
.
adjoint
@
(
ivar
.
ptw
(
'sqrt'
)
*
r
)
+
ivar
.
adjoint
@
(
sc
*
ivar
.
ptw
(
'log'
))
class
_SpecialGammaEnergy
(
EnergyOperator
):
class
_SpecialGammaEnergy
(
LikelihoodOperator
):
def
__init__
(
self
,
residual
):
def
__init__
(
self
,
residual
):
self
.
_domain
=
DomainTuple
.
make
(
residual
.
domain
)
self
.
_domain
=
DomainTuple
.
make
(
residual
.
domain
)
self
.
_resi
=
residual
self
.
_resi
=
residual
self
.
_cplx
=
_iscomplex
(
self
.
_resi
.
dtype
)
self
.
_cplx
=
_iscomplex
(
self
.
_resi
.
dtype
)
self
.
_
scale
=
ScalingOperator
(
self
.
_domain
,
1
if
self
.
_cplx
else
.
5
)
self
.
_
dt
=
self
.
_resi
.
dtype
def
apply
(
self
,
x
):
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
self
.
_check_input
(
x
)
...
@@ -214,11 +257,13 @@ class _SpecialGammaEnergy(EnergyOperator):
...
@@ -214,11 +257,13 @@ class _SpecialGammaEnergy(EnergyOperator):
res
=
0.5
*
((
r
*
x
).
vdot
(
r
)
-
x
.
ptw
(
"log"
).
sum
())
res
=
0.5
*
((
r
*
x
).
vdot
(
r
)
-
x
.
ptw
(
"log"
).
sum
())
if
not
x
.
want_metric
:
if
not
x
.
want_metric
:
return
res
return
res
met
=
makeOp
((
self
.
_scale
(
x
.
val
))
**
(
-
2
))
return
res
.
add_metric
(
self
.
get_metric_at
(
x
.
val
))
return
res
.
add_metric
(
SamplingDtypeSetter
(
met
,
self
.
_resi
.
dtype
))
def
get_transformation
(
self
):
sc
=
1.
if
self
.
_cplx
else
np
.
sqrt
(
0.5
)
return
self
.
_dt
,
sc
*
ScalingOperator
(
self
.
_domain
,
1.
).
ptw
(
'log'
)
class
GaussianEnergy
(
Energy
Operator
):
class
GaussianEnergy
(
Likelihood
Operator
):
"""Computes a negative-log Gaussian.
"""Computes a negative-log Gaussian.
Represents up to constants in :math:`m`:
Represents up to constants in :math:`m`:
...
@@ -301,15 +346,22 @@ class GaussianEnergy(EnergyOperator):
...
@@ -301,15 +346,22 @@ class GaussianEnergy(EnergyOperator):
residual
=
x
if
self
.
_mean
is
None
else
x
-
self
.
_mean
residual
=
x
if
self
.
_mean
is
None
else
x
-
self
.
_mean
res
=
self
.
_op
(
residual
).
real
res
=
self
.
_op
(
residual
).
real
if
x
.
want_metric
:
if
x
.
want_metric
:
return
res
.
add_metric
(
self
.
_met
)
return
res
.
add_metric
(
self
.
get
_met
ric_at
(
x
.
val
)
)
return
res
return
res
def
get_transformation
(
self
):
icov
,
dtp
=
self
.
_met
,
None
if
isinstance
(
icov
,
SamplingDtypeSetter
):
dtp
=
icov
.
_dtype
icov
=
icov
.
_op
return
dtp
,
icov
.
get_sqrt
()
def
__repr__
(
self
):
def
__repr__
(
self
):
dom
=
'()'
if
isinstance
(
self
.
domain
,
DomainTuple
)
else
self
.
domain
.
keys
()
dom
=
'()'
if
isinstance
(
self
.
domain
,
DomainTuple
)
else
self
.
domain
.
keys
()
return
f
'GaussianEnergy
{
dom
}
'
return
f
'GaussianEnergy
{
dom
}
'
class
PoissonianEnergy
(
Energy
Operator
):
class
PoissonianEnergy
(
Likelihood
Operator
):
"""Computes likelihood Hamiltonians of expected count field constrained by
"""Computes likelihood Hamiltonians of expected count field constrained by
Poissonian count data.
Poissonian count data.
...
@@ -341,10 +393,12 @@ class PoissonianEnergy(EnergyOperator):
...
@@ -341,10 +393,12 @@ class PoissonianEnergy(EnergyOperator):
res
=
x
.
sum
()
-
x
.
ptw
(
"log"
).
vdot
(
self
.
_d
)
res
=
x
.
sum
()
-
x
.
ptw
(
"log"
).
vdot
(
self
.
_d
)
if
not
x
.
want_metric
:
if
not
x
.
want_metric
:
return
res
return
res
return
res
.
add_metric
(
SamplingDtypeSetter
(
makeOp
(
1.
/
x
.
val
),
np
.
float64
))
return
res
.
add_metric
(
self
.
get_metric_at
(
x
.
val
))
def
get_transformation
(
self
):
return
np
.
float64
,
2.
*
ScalingOperator
(
self
.
_domain
,
1.
).
sqrt
()
class
InverseGammaLikelihood
(
Energy
Operator
):
class
InverseGammaLikelihood
(
Likelihood
Operator
):
"""Computes the negative log-likelihood of the inverse gamma distribution.
"""Computes the negative log-likelihood of the inverse gamma distribution.
It negative log-pdf(x) is given by
It negative log-pdf(x) is given by
...
@@ -385,13 +439,14 @@ class InverseGammaLikelihood(EnergyOperator):
...
@@ -385,13 +439,14 @@ class InverseGammaLikelihood(EnergyOperator):
res
=
x
.
ptw
(
"log"
).
vdot
(
self
.
_alphap1
)
+
x
.
ptw
(
"reciprocal"
).
vdot
(
self
.
_beta
)
res
=
x
.
ptw
(
"log"
).
vdot
(
self
.
_alphap1
)
+
x
.
ptw
(
"reciprocal"
).
vdot
(
self
.
_beta
)
if
not
x
.
want_metric
:
if
not
x
.
want_metric
:
return
res
return
res
met
=
makeOp
(
self
.
_alphap1
/
(
x
.
val
**
2
))
return
res
.
add_metric
(
self
.
get_metric_at
(
x
.
val
))
if
self
.
_sampling_dtype
is
not
None
:
met
=
SamplingDtypeSetter
(
met
,
self
.
_sampling_dtype
)
return
res
.
add_metric
(
met
)
def
get_transformation
(
self
):
fact
=
self
.
_alphap1
.
ptw
(
'sqrt'
)
res
=
makeOp
(
fact
)
@
ScalingOperator
(
self
.
_domain
,
1.
).
ptw
(
'log'
)
return
self
.
_sampling_dtype
,
res
class
StudentTEnergy
(
Energy
Operator
):
class
StudentTEnergy
(
Likelihood
Operator
):
"""Computes likelihood energy corresponding to Student's t-distribution.
"""Computes likelihood energy corresponding to Student's t-distribution.
.. math ::
.. math ::
...
@@ -418,11 +473,17 @@ class StudentTEnergy(EnergyOperator):
...
@@ -418,11 +473,17 @@ class StudentTEnergy(EnergyOperator):
res
=
(((
self
.
_theta
+
1
)
/
2
)
*
(
x
**
2
/
self
.
_theta
).
ptw
(
"log1p"
)).
sum
()
res
=
(((
self
.
_theta
+
1
)
/
2
)
*
(
x
**
2
/
self
.
_theta
).
ptw
(
"log1p"
)).
sum
()
if
not
x
.
want_metric
:
if
not
x
.
want_metric
:
return
res
return
res
met
=
makeOp
((
self
.
_theta
+
1
)
/
(
self
.
_theta
+
3
),
self
.
domain
)
return
res
.
add_metric
(
self
.
get_metric_at
(
x
.
val
))
return
res
.
add_metric
(
SamplingDtypeSetter
(
met
,
np
.
float64
))
def
get_transformation
(
self
):
if
isinstance
(
self
.
_theta
,
Field
)
or
isinstance
(
self
.
_theta
,
MultiField
):
th
=
self
.
_theta
else
:
from
..extra
import
full
th
=
full
(
self
.
_domain
,
self
.
_theta
)
return
np
.
float64
,
makeOp
(((
th
+
1
)
/
(
th
+
3
)).
ptw
(
'sqrt'
))
class
BernoulliEnergy
(
Energy
Operator
):
class
BernoulliEnergy
(
Likelihood
Operator
):
"""Computes likelihood energy of expected event frequency constrained by
"""Computes likelihood energy of expected event frequency constrained by
event data.
event data.
...
@@ -452,9 +513,13 @@ class BernoulliEnergy(EnergyOperator):
...
@@ -452,9 +513,13 @@ class BernoulliEnergy(EnergyOperator):
res
=
-
x
.
ptw
(
"log"
).
vdot
(
self
.
_d
)
+
(
1.
-
x
).
ptw
(
"log"
).
vdot
(
self
.
_d
-
1.
)
res
=
-
x
.
ptw
(
"log"
).
vdot
(
self
.
_d
)
+
(
1.
-
x
).
ptw
(
"log"
).
vdot
(
self
.
_d
-
1.
)
if
not
x
.
want_metric
:
if
not
x
.
want_metric
:
return
res
return
res
met
=
makeOp
(
1.
/
(
x
.
val
*
(
1.
-
x
.
val
)))
return
res
.
add_metric
(
self
.
get_metric_at
(
x
.
val
))
return
res
.
add_metric
(
SamplingDtypeSetter
(
met
,
np
.
float64
))
def
get_transformation
(
self
):
from
..extra
import
full
res
=
Adder
(
full
(
self
.
_domain
,
1.
))
@
ScalingOperator
(
self
.
_domain
,
-
1
)
res
=
res
*
ScalingOperator
(
self
.
_domain
,
1
).
ptw
(
'reciprocal'
)
return
np
.
float64
,
-
2.
*
res
.
ptw
(
'sqrt'
).
ptw
(
'arctan'
)
class
StandardHamiltonian
(
EnergyOperator
):
class
StandardHamiltonian
(
EnergyOperator
):
"""Computes an information Hamiltonian in its standard form, i.e. with the
"""Computes an information Hamiltonian in its standard form, i.e. with the
...
@@ -546,3 +611,11 @@ class AveragedEnergy(EnergyOperator):
...
@@ -546,3 +611,11 @@ class AveragedEnergy(EnergyOperator):
self
.
_check_input
(
x
)
self
.
_check_input
(
x
)
mymap
=
map
(
lambda
v
:
self
.
_h
(
x
+
v
),
self
.
_res_samples
)
mymap
=
map
(
lambda
v
:
self
.
_h
(
x
+
v
),
self
.
_res_samples
)
return
utilities
.
my_sum
(
mymap
)
/
len
(
self
.
_res_samples
)
return
utilities
.
my_sum
(
mymap
)
/
len
(
self
.
_res_samples
)
def
get_transformation
(
self
):
dtp
,
trafo
=
self
.
_h
.
get_transformation
()
mymap
=
map
(
lambda
v
:
trafo
@
Adder
(
v
),
self
.
_res_samples
)
return
dtp
,
utilities
.
my_sum
(
mymap
)
/
np
.
sqrt
(
len
(
self
.
_res_samples
))
src/operators/operator.py
View file @
21f0128d
...
@@ -107,6 +107,23 @@ class Operator(metaclass=NiftyMeta):
...
@@ -107,6 +107,23 @@ class Operator(metaclass=NiftyMeta):
"""
"""
return
None
return
None
def
get_transformation
(
self
):
"""The coordinate transformation that maps into a coordinate system
where the metric of a likelihood is the Euclidean metric.
This is `None`, except when the object is considered a likelihood i.E.
for an instance of `EnergyOperator` with its metric being a proper
Fisher information metric, or a sum or nested sum thereof.
Retruns
-------
np.dtype, or dict of np.dtype : The dtype(s) of the target space of the
transformation.
Operator : The transformation that maps from `domain` into the
Euclidean target space.
"""
return
None
@
staticmethod
@
staticmethod
def
_check_domain_equality
(
dom_op
,
dom_field
):
def
_check_domain_equality
(
dom_op
,
dom_field
):
if
dom_op
!=
dom_field
:
if
dom_op
!=
dom_field
:
...
@@ -402,6 +419,12 @@ class _OpChain(_CombinedOperator):
...
@@ -402,6 +419,12 @@ class _OpChain(_CombinedOperator):
x
=
op
(
x
)
x
=
op
(
x
)
return
x
return
x
def
get_transformation
(
self
):
tr
=
self
.
_ops
[
0
].
get_transformation
()
if
tr
is
None
:
return
tr
return
tr
[
0
],
_OpChain
.
make
((
tr
[
1
],)
+
self
.
_ops
[
1
:])
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
..multi_domain
import
MultiDomain
from
..multi_domain
import
MultiDomain
if
not
isinstance
(
self
.
_domain
,
MultiDomain
):
if
not
isinstance
(
self
.
_domain
,
MultiDomain
):
...
@@ -488,6 +511,25 @@ class _OpSum(Operator):
...
@@ -488,6 +511,25 @@ class _OpSum(Operator):
res
=
res
.
add_metric
(
lin1
.
_metric
.
_myadd
(
lin2
.
_metric
,
False
))
res
=
res
.
add_metric
(
lin1
.
_metric
.
_myadd
(
lin2
.
_metric
,
False
))
return
res
return
res
def
get_transformation
(
self
):
tr1
=
self
.
_op1
.
get_transformation
()
tr2
=
self
.
_op2
.
get_transformation
()
if
tr1
is
None
or
tr2
is
None
: