Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
21f0128d
Commit
21f0128d
authored
May 28, 2021
by
Philipp Frank
Committed by
Philipp Arras
May 28, 2021
Browse files
Implement Geometric KL
parent
0acc9ff7
Changes
16
Expand all
Hide whitespace changes
Inline
Side-by-side
ChangeLog.md
View file @
21f0128d
...
...
@@ -64,13 +64,29 @@ The implementation tests for nonlinear operators are now available in
`ift.extra.check_operator()`
and for linear operators
`ift.extra.check_linear_operator()`
.
MetricGaussianKL interface
--------------------------
Users do not instantiate
`MetricGaussianKL`
by its constructor anymore. Rather
`MetricGaussianKL.make()`
shall be used. Additionally,
`mirror_samples`
is not
set by default anymore.
`mirror_samples`
is not set by default anymore.
GeoMetricKL
-----------
A new posterior approximation scheme, called geometric Variational Inference
(geoVI) was introduced.
`GeoMetricKL`
is analogous to
`MetricGaussianKL`
with
the exception that geoVI samples are used instead of MGVI samples. For further
details see (
<https://arxiv.org/abs/2105.10470>
).
LikelihoodOperator
------------------
A new subclass of
`EnergyOperator`
was introduced and all
`EnergyOperator`
s
that are likelihoods are now
`LikelihoodOperator`
s. A
`LikelihoodOperator`
has to implement the function
`get_transformation`
, which returns a
coordinate transformation in which the Fisher metric of the likelihood becomes
the identity matrix.
Changes since NIFTy 5
...
...
demos/getting_started_3.py
View file @
21f0128d
...
...
@@ -123,7 +123,7 @@ def main():
# Draw new samples to approximate the KL five times
for
i
in
range
(
5
):
# Draw new samples and minimize KL
KL
=
ift
.
MetricGaussianKL
.
make
(
mean
,
H
,
N_samples
,
True
)
KL
=
ift
.
MetricGaussianKL
(
mean
,
H
,
N_samples
,
True
)
KL
,
convergence
=
minimizer
(
KL
)
mean
=
KL
.
position
ift
.
extra
.
minisanity
(
data
,
lambda
x
:
N
.
inverse
,
signal_response
,
...
...
demos/getting_started_5_mf.py
View file @
21f0128d
...
...
@@ -134,7 +134,7 @@ def main():
for
i
in
range
(
5
):
# Draw new samples and minimize KL
KL
=
ift
.
MetricGaussianKL
.
make
(
mean
,
H
,
N_samples
,
True
)
KL
=
ift
.
MetricGaussianKL
(
mean
,
H
,
N_samples
,
True
)
KL
,
convergence
=
minimizer
(
KL
)
mean
=
KL
.
position
...
...
demos/getting_started_density.py
View file @
21f0128d
...
...
@@ -94,7 +94,7 @@ if __name__ == "__main__":
for
i
in
range
(
5
):
# Draw new samples and minimize KL
kl
=
ift
.
MetricGaussianKL
.
make
(
mean
,
ham
,
n_samples
,
True
)
kl
=
ift
.
MetricGaussianKL
(
mean
,
ham
,
n_samples
,
True
)
kl
,
convergence
=
minimizer
(
kl
)
mean
=
kl
.
position
...
...
demos/mgvi_visualized.py
View file @
21f0128d
...
...
@@ -36,6 +36,8 @@ import nifty7 as ift
def
main
():
use_geo
=
False
name
=
'GEO'
if
use_geo
else
'MGVI'
dom
=
ift
.
UnstructuredDomain
(
1
)
scale
=
10
...
...
@@ -91,12 +93,16 @@ def main():
plt
.
figure
(
figsize
=
[
12
,
8
])
for
ii
in
range
(
15
):
if
ii
%
3
==
0
:
mgkl
=
ift
.
MetricGaussianKL
.
make
(
pos
,
ham
,
40
,
False
)
if
use_geo
:
mini_samp
=
ift
.
NewtonCG
(
ift
.
GradientNormController
(
iteration_limit
=
5
))
mgkl
=
ift
.
GeoMetricKL
(
pos
,
ham
,
100
,
mini_samp
,
False
)
else
:
mgkl
=
ift
.
MetricGaussianKL
(
pos
,
ham
,
100
,
False
)
plt
.
cla
()
plt
.
imshow
(
z
.
T
,
origin
=
'lower'
,
norm
=
LogNorm
(),
vmin
=
1e-3
,
vmax
=
np
.
max
(
z
),
cmap
=
'gist_earth_r'
,
extent
=
x_limits_scaled
+
y_limits
)
plt
.
imshow
(
z
.
T
,
origin
=
'lower'
,
norm
=
LogNorm
(
vmin
=
1e-3
,
vmax
=
np
.
max
(
z
)),
cmap
=
'gist_earth_r'
,
extent
=
x_limits_scaled
+
y_limits
)
if
ii
==
0
:
cbar
=
plt
.
colorbar
()
cbar
.
ax
.
set_ylabel
(
'pdf'
)
...
...
@@ -105,12 +111,14 @@ def main():
samp
=
(
samp
+
pos
).
val
xs
.
append
(
samp
[
'a'
])
ys
.
append
(
samp
[
'b'
])
plt
.
scatter
(
np
.
array
(
xs
)
*
scale
,
np
.
array
(
ys
),
label
=
'MGVI samples'
)
plt
.
scatter
(
pos
.
val
[
'a'
]
*
scale
,
pos
.
val
[
'b'
],
label
=
'MGVI latent mean'
)
plt
.
scatter
(
np
.
array
(
map_xs
)
*
scale
,
np
.
array
(
map_ys
),
label
=
'Laplace samples'
)
plt
.
scatter
(
np
.
array
(
xs
)
*
scale
,
np
.
array
(
ys
),
label
=
name
+
' samples'
)
plt
.
scatter
(
pos
.
val
[
'a'
]
*
scale
,
pos
.
val
[
'b'
],
label
=
name
+
' latent mean'
)
plt
.
scatter
(
MAP
.
position
.
val
[
'a'
]
*
scale
,
MAP
.
position
.
val
[
'b'
],
label
=
'Maximum a posterior solution'
)
plt
.
xlim
(
x_limits_scaled
)
plt
.
ylim
(
y_limits
)
plt
.
legend
()
plt
.
draw
()
plt
.
pause
(
1.0
)
...
...
src/__init__.py
View file @
21f0128d
...
...
@@ -71,7 +71,7 @@ from .minimization.scipy_minimizer import L_BFGS_B
from
.minimization.energy
import
Energy
from
.minimization.quadratic_energy
import
QuadraticEnergy
from
.minimization.energy_adapter
import
EnergyAdapter
from
.minimization.
metric_gaussian_kl
import
MetricGaussianKL
from
.minimization.
kl_energies
import
MetricGaussianKL
,
GeoMetricKL
from
.sugar
import
*
...
...
src/extra.py
View file @
21f0128d
...
...
@@ -526,3 +526,22 @@ def _tableentries(redchisq, scmean, ndof, keylen):
out
+=
f
"
{
ndof
[
kk
]:
>
11
}
"
out
+=
"
\n
"
return
out
[:
-
1
]
class
_KeyModifier
(
LinearOperator
):
def
__init__
(
self
,
domain
,
pre
):
if
not
isinstance
(
domain
,
MultiDomain
):
raise
ValueError
from
.sugar
import
makeDomain
self
.
_domain
=
makeDomain
(
domain
)
self
.
_pre
=
str
(
pre
)
target
=
{
self
.
_pre
+
k
:
domain
[
k
]
for
k
in
domain
.
keys
()}
self
.
_target
=
makeDomain
(
MultiDomain
.
make
(
target
))
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
res
=
{
self
.
_pre
+
k
:
x
[
k
]
for
k
in
self
.
_domain
.
keys
()}
else
:
res
=
{
k
:
x
[
self
.
_pre
+
k
]
for
k
in
self
.
_domain
.
keys
()}
return
MultiField
.
from_dict
(
res
,
domain
=
self
.
_tgt
(
mode
))
src/minimization/
metric_gaussian_kl
.py
→
src/minimization/
kl_energies
.py
View file @
21f0128d
This diff is collapsed.
Click to expand it.
src/operators/energy_operators.py
View file @
21f0128d
...
...
@@ -20,15 +20,18 @@ import numpy as np
from
..
import
utilities
from
..domain_tuple
import
DomainTuple
from
..field
import
Field
from
..linearization
import
Linearization
from
..multi_domain
import
MultiDomain
from
..multi_field
import
MultiField
from
..sugar
import
makeDomain
,
makeOp
from
..utilities
import
myassert
from
.linear_operator
import
LinearOperator
from
.operator
import
Operator
from
.adder
import
Adder
from
.sampling_enabler
import
SamplingDtypeSetter
,
SamplingEnabler
from
.scaling_operator
import
ScalingOperator
from
.simple_linear_operators
import
VdotOperator
from
.sandwich_operator
import
SandwichOperator
from
.simple_linear_operators
import
VdotOperator
,
FieldAdapter
def
_check_sampling_dtype
(
domain
,
dtypes
):
...
...
@@ -78,6 +81,24 @@ class EnergyOperator(Operator):
_target
=
DomainTuple
.
scalar_domain
()
class
LikelihoodOperator
(
EnergyOperator
):
"""`EnergyOperator` representing a likelihood. The input to the Operator
are the parameters of the likelihood. Unlike a general `EnergyOperator`,
the metric of a `LikelihoodOperator` is the Fisher information metric of
the likelihood.
"""
def
get_metric_at
(
self
,
x
):
"""Computes the Fisher information metric for a `LikelihoodOperator`
at `x` using the Jacobian of the coordinate transformation given by
`get_transformation`.
"""
dtp
,
f
=
self
.
get_transformation
()
ch
=
ScalingOperator
(
f
.
target
,
1.
)
if
dtp
is
not
None
:
ch
=
SamplingDtypeSetter
(
ch
,
dtp
)
return
SandwichOperator
.
make
(
f
(
Linearization
.
make_var
(
x
)).
jac
,
ch
)
class
Squared2NormOperator
(
EnergyOperator
):
"""Computes the square of the L2-norm of the output of an operator.
...
...
@@ -126,7 +147,7 @@ class QuadraticFormOperator(EnergyOperator):
return
x
.
new
(
res
,
VdotOperator
(
self
.
_op
(
x
.
val
)))
class
VariableCovarianceGaussianEnergy
(
Energy
Operator
):
class
VariableCovarianceGaussianEnergy
(
Likelihood
Operator
):
"""Computes the negative log pdf of a Gaussian with unknown covariance.
The covariance is assumed to be diagonal.
...
...
@@ -152,9 +173,16 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
sampling_dtype : np.dtype
Data type of the samples. Usually either 'np.float*' or 'np.complex*'
use_full_fisher: boolean
Whether or not the proper Fisher information metric should be used as
a `metric`. If False the same approximation used in
`get_transformation` is used instead.
Default is True
"""
def
__init__
(
self
,
domain
,
residual_key
,
inverse_covariance_key
,
sampling_dtype
):
def
__init__
(
self
,
domain
,
residual_key
,
inverse_covariance_key
,
sampling_dtype
,
use_full_fisher
=
True
):
self
.
_kr
=
str
(
residual_key
)
self
.
_ki
=
str
(
inverse_covariance_key
)
dom
=
DomainTuple
.
make
(
domain
)
...
...
@@ -162,6 +190,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
self
.
_dt
=
{
self
.
_kr
:
sampling_dtype
,
self
.
_ki
:
np
.
float64
}
_check_sampling_dtype
(
self
.
_domain
,
self
.
_dt
)
self
.
_cplx
=
_iscomplex
(
sampling_dtype
)
self
.
_use_fisher
=
use_full_fisher
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
...
...
@@ -172,10 +201,14 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
res
=
0.5
*
(
r
.
vdot
(
r
*
i
)
-
i
.
ptw
(
"log"
).
sum
())
if
not
x
.
want_metric
:
return
res
met
=
1.
if
self
.
_cplx
else
.
5
met
=
MultiField
.
from_dict
({
self
.
_kr
:
i
.
val
,
self
.
_ki
:
met
*
i
.
val
**
(
-
2
)},
domain
=
self
.
_domain
)
return
res
.
add_metric
(
SamplingDtypeSetter
(
makeOp
(
met
),
self
.
_dt
))
if
self
.
_use_fisher
:
met
=
1.
if
self
.
_cplx
else
0.5
met
=
MultiField
.
from_dict
({
self
.
_kr
:
i
.
val
,
self
.
_ki
:
met
*
i
.
val
**
(
-
2
)},
domain
=
self
.
_domain
)
met
=
SamplingDtypeSetter
(
makeOp
(
met
),
self
.
_dt
)
else
:
met
=
self
.
get_metric_at
(
x
.
val
)
return
res
.
add_metric
(
met
)
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
.simplify_for_const
import
ConstantEnergyOperator
...
...
@@ -190,20 +223,30 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
res
=
GaussianEnergy
(
inverse_covariance
=
makeOp
(
cst
),
sampling_dtype
=
dt
).
ducktape
(
self
.
_kr
)
trlog
=
cst
.
log
().
sum
().
val_rw
()
if
not
_iscomplex
(
dt
)
:
if
not
self
.
_cplx
:
trlog
/=
2
res
=
res
+
ConstantEnergyOperator
(
-
trlog
)
res
=
res
+
ConstantEnergyOperator
(
0.
)
myassert
(
res
.
target
is
self
.
target
)
return
None
,
res
def
get_transformation
(
self
):
"""Note that for the metric of a `VariableCovarianceGaussianEnergy` no
global transformation to Euclidean space exists. A local approximation
ivoking the resudual is used instead.
"""
r
=
FieldAdapter
(
self
.
_domain
[
self
.
_kr
],
self
.
_kr
)
ivar
=
FieldAdapter
(
self
.
_domain
[
self
.
_kr
],
self
.
_ki
)
sc
=
1.
if
self
.
_cplx
else
0.5
return
self
.
_dt
,
r
.
adjoint
@
(
ivar
.
ptw
(
'sqrt'
)
*
r
)
+
ivar
.
adjoint
@
(
sc
*
ivar
.
ptw
(
'log'
))
class
_SpecialGammaEnergy
(
EnergyOperator
):
class
_SpecialGammaEnergy
(
LikelihoodOperator
):
def
__init__
(
self
,
residual
):
self
.
_domain
=
DomainTuple
.
make
(
residual
.
domain
)
self
.
_resi
=
residual
self
.
_cplx
=
_iscomplex
(
self
.
_resi
.
dtype
)
self
.
_
scale
=
ScalingOperator
(
self
.
_domain
,
1
if
self
.
_cplx
else
.
5
)
self
.
_
dt
=
self
.
_resi
.
dtype
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
...
...
@@ -214,11 +257,13 @@ class _SpecialGammaEnergy(EnergyOperator):
res
=
0.5
*
((
r
*
x
).
vdot
(
r
)
-
x
.
ptw
(
"log"
).
sum
())
if
not
x
.
want_metric
:
return
res
met
=
makeOp
((
self
.
_scale
(
x
.
val
))
**
(
-
2
))
return
res
.
add_metric
(
SamplingDtypeSetter
(
met
,
self
.
_resi
.
dtype
))
return
res
.
add_metric
(
self
.
get_metric_at
(
x
.
val
))
def
get_transformation
(
self
):
sc
=
1.
if
self
.
_cplx
else
np
.
sqrt
(
0.5
)
return
self
.
_dt
,
sc
*
ScalingOperator
(
self
.
_domain
,
1.
).
ptw
(
'log'
)
class
GaussianEnergy
(
Energy
Operator
):
class
GaussianEnergy
(
Likelihood
Operator
):
"""Computes a negative-log Gaussian.
Represents up to constants in :math:`m`:
...
...
@@ -301,15 +346,22 @@ class GaussianEnergy(EnergyOperator):
residual
=
x
if
self
.
_mean
is
None
else
x
-
self
.
_mean
res
=
self
.
_op
(
residual
).
real
if
x
.
want_metric
:
return
res
.
add_metric
(
self
.
_met
)
return
res
.
add_metric
(
self
.
get
_met
ric_at
(
x
.
val
)
)
return
res
def
get_transformation
(
self
):
icov
,
dtp
=
self
.
_met
,
None
if
isinstance
(
icov
,
SamplingDtypeSetter
):
dtp
=
icov
.
_dtype
icov
=
icov
.
_op
return
dtp
,
icov
.
get_sqrt
()
def
__repr__
(
self
):
dom
=
'()'
if
isinstance
(
self
.
domain
,
DomainTuple
)
else
self
.
domain
.
keys
()
return
f
'GaussianEnergy
{
dom
}
'
class
PoissonianEnergy
(
Energy
Operator
):
class
PoissonianEnergy
(
Likelihood
Operator
):
"""Computes likelihood Hamiltonians of expected count field constrained by
Poissonian count data.
...
...
@@ -341,10 +393,12 @@ class PoissonianEnergy(EnergyOperator):
res
=
x
.
sum
()
-
x
.
ptw
(
"log"
).
vdot
(
self
.
_d
)
if
not
x
.
want_metric
:
return
res
return
res
.
add_metric
(
SamplingDtypeSetter
(
makeOp
(
1.
/
x
.
val
),
np
.
float64
))
return
res
.
add_metric
(
self
.
get_metric_at
(
x
.
val
))
def
get_transformation
(
self
):
return
np
.
float64
,
2.
*
ScalingOperator
(
self
.
_domain
,
1.
).
sqrt
()
class
InverseGammaLikelihood
(
Energy
Operator
):
class
InverseGammaLikelihood
(
Likelihood
Operator
):
"""Computes the negative log-likelihood of the inverse gamma distribution.
It negative log-pdf(x) is given by
...
...
@@ -385,13 +439,14 @@ class InverseGammaLikelihood(EnergyOperator):
res
=
x
.
ptw
(
"log"
).
vdot
(
self
.
_alphap1
)
+
x
.
ptw
(
"reciprocal"
).
vdot
(
self
.
_beta
)
if
not
x
.
want_metric
:
return
res
met
=
makeOp
(
self
.
_alphap1
/
(
x
.
val
**
2
))
if
self
.
_sampling_dtype
is
not
None
:
met
=
SamplingDtypeSetter
(
met
,
self
.
_sampling_dtype
)
return
res
.
add_metric
(
met
)
return
res
.
add_metric
(
self
.
get_metric_at
(
x
.
val
))
def
get_transformation
(
self
):
fact
=
self
.
_alphap1
.
ptw
(
'sqrt'
)
res
=
makeOp
(
fact
)
@
ScalingOperator
(
self
.
_domain
,
1.
).
ptw
(
'log'
)
return
self
.
_sampling_dtype
,
res
class
StudentTEnergy
(
Energy
Operator
):
class
StudentTEnergy
(
Likelihood
Operator
):
"""Computes likelihood energy corresponding to Student's t-distribution.
.. math ::
...
...
@@ -418,11 +473,17 @@ class StudentTEnergy(EnergyOperator):
res
=
(((
self
.
_theta
+
1
)
/
2
)
*
(
x
**
2
/
self
.
_theta
).
ptw
(
"log1p"
)).
sum
()
if
not
x
.
want_metric
:
return
res
met
=
makeOp
((
self
.
_theta
+
1
)
/
(
self
.
_theta
+
3
),
self
.
domain
)
return
res
.
add_metric
(
SamplingDtypeSetter
(
met
,
np
.
float64
))
return
res
.
add_metric
(
self
.
get_metric_at
(
x
.
val
))
def
get_transformation
(
self
):
if
isinstance
(
self
.
_theta
,
Field
)
or
isinstance
(
self
.
_theta
,
MultiField
):
th
=
self
.
_theta
else
:
from
..extra
import
full
th
=
full
(
self
.
_domain
,
self
.
_theta
)
return
np
.
float64
,
makeOp
(((
th
+
1
)
/
(
th
+
3
)).
ptw
(
'sqrt'
))
class
BernoulliEnergy
(
Energy
Operator
):
class
BernoulliEnergy
(
Likelihood
Operator
):
"""Computes likelihood energy of expected event frequency constrained by
event data.
...
...
@@ -452,9 +513,13 @@ class BernoulliEnergy(EnergyOperator):
res
=
-
x
.
ptw
(
"log"
).
vdot
(
self
.
_d
)
+
(
1.
-
x
).
ptw
(
"log"
).
vdot
(
self
.
_d
-
1.
)
if
not
x
.
want_metric
:
return
res
met
=
makeOp
(
1.
/
(
x
.
val
*
(
1.
-
x
.
val
)))
return
res
.
add_metric
(
SamplingDtypeSetter
(
met
,
np
.
float64
))
return
res
.
add_metric
(
self
.
get_metric_at
(
x
.
val
))
def
get_transformation
(
self
):
from
..extra
import
full
res
=
Adder
(
full
(
self
.
_domain
,
1.
))
@
ScalingOperator
(
self
.
_domain
,
-
1
)
res
=
res
*
ScalingOperator
(
self
.
_domain
,
1
).
ptw
(
'reciprocal'
)
return
np
.
float64
,
-
2.
*
res
.
ptw
(
'sqrt'
).
ptw
(
'arctan'
)
class
StandardHamiltonian
(
EnergyOperator
):
"""Computes an information Hamiltonian in its standard form, i.e. with the
...
...
@@ -546,3 +611,11 @@ class AveragedEnergy(EnergyOperator):
self
.
_check_input
(
x
)
mymap
=
map
(
lambda
v
:
self
.
_h
(
x
+
v
),
self
.
_res_samples
)
return
utilities
.
my_sum
(
mymap
)
/
len
(
self
.
_res_samples
)
def
get_transformation
(
self
):
dtp
,
trafo
=
self
.
_h
.
get_transformation
()
mymap
=
map
(
lambda
v
:
trafo
@
Adder
(
v
),
self
.
_res_samples
)
return
dtp
,
utilities
.
my_sum
(
mymap
)
/
np
.
sqrt
(
len
(
self
.
_res_samples
))
src/operators/operator.py
View file @
21f0128d
...
...
@@ -107,6 +107,23 @@ class Operator(metaclass=NiftyMeta):
"""
return
None
def
get_transformation
(
self
):
"""The coordinate transformation that maps into a coordinate system
where the metric of a likelihood is the Euclidean metric.
This is `None`, except when the object is considered a likelihood i.E.
for an instance of `EnergyOperator` with its metric being a proper
Fisher information metric, or a sum or nested sum thereof.
Retruns
-------
np.dtype, or dict of np.dtype : The dtype(s) of the target space of the
transformation.
Operator : The transformation that maps from `domain` into the
Euclidean target space.
"""
return
None
@
staticmethod
def
_check_domain_equality
(
dom_op
,
dom_field
):
if
dom_op
!=
dom_field
:
...
...
@@ -402,6 +419,12 @@ class _OpChain(_CombinedOperator):
x
=
op
(
x
)
return
x
def
get_transformation
(
self
):
tr
=
self
.
_ops
[
0
].
get_transformation
()
if
tr
is
None
:
return
tr
return
tr
[
0
],
_OpChain
.
make
((
tr
[
1
],)
+
self
.
_ops
[
1
:])
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
..multi_domain
import
MultiDomain
if
not
isinstance
(
self
.
_domain
,
MultiDomain
):
...
...
@@ -488,6 +511,25 @@ class _OpSum(Operator):
res
=
res
.
add_metric
(
lin1
.
_metric
.
_myadd
(
lin2
.
_metric
,
False
))
return
res
def
get_transformation
(
self
):
tr1
=
self
.
_op1
.
get_transformation
()
tr2
=
self
.
_op2
.
get_transformation
()
if
tr1
is
None
or
tr2
is
None
:
return
None
from
..extra
import
_KeyModifier
dtype
,
trafo
=
{},
None
for
i
,
lh
in
enumerate
([
self
.
_op1
,
self
.
_op2
]):
dtp
,
tr
=
lh
.
get_transformation
()
if
isinstance
(
tr
.
target
,
MultiDomain
):
dtype
.
update
({
str
(
i
)
+
d
:
dtp
[
d
]
for
d
in
dtp
.
keys
()})
tr
=
_KeyModifier
(
tr
.
target
,
str
(
i
))
@
tr
trafo
=
tr
if
trafo
is
None
else
trafo
+
tr
else
:
dtype
[
str
(
i
)]
=
dtp
tr
=
tr
.
ducktape_left
(
str
(
i
))
trafo
=
tr
if
trafo
is
None
else
trafo
+
tr
return
dtype
,
trafo
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
..multi_domain
import
MultiDomain
from
.simplify_for_const
import
ConstCollector
...
...
src/operators/sandwich_operator.py
View file @
21f0128d
...
...
@@ -73,6 +73,9 @@ class SandwichOperator(EndomorphicOperator):
# If our sandwich is diagonal, we can return immediately
if
isinstance
(
op
,
(
ScalingOperator
,
DiagonalOperator
)):
if
isinstance
(
cheese
,
SamplingDtypeSetter
):
#FIXME
return
SamplingDtypeSetter
(
op
,
cheese
.
_dtype
)
return
op
return
SandwichOperator
(
bun
,
cheese
,
op
,
_callingfrommake
=
True
)
...
...
src/sugar.py
View file @
21f0128d
...
...
@@ -39,7 +39,7 @@ from .plot import Plot
__all__
=
[
'PS_field'
,
'power_analyze'
,
'create_power_operator'
,
'density_estimator'
,
'create_harmonic_smoothing_operator'
,
'from_random'
,
'full'
,
'makeField'
,
'is_fieldlike'
,
'is_linearization'
,
'is_operator'
,
'makeDomain'
,
'is_linearization'
,
'is_operator'
,
'makeDomain'
,
'is_likelihood'
,
'get_signal_variance'
,
'makeOp'
,
'domain_union'
,
'get_default_codomain'
,
'single_plot'
,
'exec_time'
,
'calculate_position'
]
+
list
(
pointwise
.
ptw_dict
.
keys
())
...
...
@@ -221,9 +221,9 @@ def density_estimator(domain, pad=1.0, cf_fluctuations=None,
from
.library.correlated_fields
import
CorrelatedFieldMaker
from
.library.special_distributions
import
UniformOperator
cf_azm_uniform_sane_default
=
(
1e-
15
,
5
.0
)
cf_azm_uniform_sane_default
=
(
1e-
4
,
1
.0
)
cf_fluctuations_sane_default
=
{
"scale"
:
(
0.
3
,
0.
2
),
"scale"
:
(
0.
5
,
0.
3
),
"cutoff"
:
(
4.0
,
3.0
),
"loglogslope"
:
(
-
6.0
,
3.0
)
}
...
...
@@ -557,7 +557,7 @@ def calculate_position(operator, output):
"""Finds approximate preimage of an operator for a given output."""
from
.minimization.descent_minimizers
import
NewtonCG
from
.minimization.iteration_controllers
import
GradientNormController
from
.minimization.
metric_gaussian_kl
import
MetricGaussianKL
from
.minimization.
kl_energies
import
MetricGaussianKL
from
.operators.scaling_operator
import
ScalingOperator
from
.operators.energy_operators
import
GaussianEnergy
,
StandardHamiltonian
if
not
isinstance
(
operator
,
Operator
):
...
...
@@ -585,18 +585,24 @@ def calculate_position(operator, output):
minimizer
=
NewtonCG
(
GradientNormController
(
iteration_limit
=
10
,
name
=
'findpos'
))
for
ii
in
range
(
3
):
logger
.
info
(
f
'Start iteration
{
ii
+
1
}
/3'
)
kl
=
MetricGaussianKL
.
make
(
pos
,
H
,
3
,
True
)
kl
=
MetricGaussianKL
(
pos
,
H
,
3
,
True
)
kl
,
_
=
minimizer
(
kl
)
pos
=
kl
.
position
return
pos
def
is_likelihood
(
obj
):
"""Checks if object is likelihood-like.
"""
return
isinstance
(
obj
,
Operator
)
and
obj
.
get_transformation
()
is
not
None
def
is_operator
(
obj
):
"""Check if object is operator-like.
"""Check
s
if object is operator-like.
Note
----
A simple `isinstance(obj, ift.Operator)` does
not
give the expected
A simple `isinstance(obj, ift.Operator)` does give the expected
result because, e.g., :class:`~nifty7.field.Field` inherits from
:class:`~nifty7.operators.operator.Operator`.
"""
...
...
@@ -604,20 +610,19 @@ def is_operator(obj):
def
is_linearization
(
obj
):
"""Check if object is linearization-like."""
"""Check
s
if object is linearization-like."""
return
isinstance
(
obj
,
Operator
)
and
obj
.
jac
is
not
None
def
is_fieldlike
(
obj
):
"""Check if object is field-like.
"""Check
s
if object is field-like.
Note
----
A simple `isinstance(obj, ift.Field)` does
not
give the expected
A simple `isinstance(obj, ift.Field)` does give the expected
result because users might have implemented another class which
behaves field-like but is not an instance of
:class:`~nifty7.field.Field`. Instances of
:class:`~nifty7.linearization.Linearization` are considered to be
field-like.
:class:`~nifty7.field.Field`. Also not that instances of
:class:`~nifty7.linearization.Linearization` behave field-like.
"""
return
isinstance
(
obj
,
Operator
)
and
obj
.
val
is
not
None
test/test_kl.py
View file @
21f0128d
...
...
@@ -31,7 +31,8 @@ pmp = pytest.mark.parametrize
@
pmp
(
'point_estimates'
,
([],
[
'a'
],
[
'b'
],
[
'a'
,
'b'
]))
@
pmp
(
'mirror_samples'
,
(
True
,
False
))
@
pmp
(
'mf'
,
(
True
,
False
))
def
test_kl
(
constants
,
point_estimates
,
mirror_samples
,
mf
):
@
pmp
(
'geo'
,
(
True
,
False
))
def
test_kl
(
constants
,
point_estimates
,
mirror_samples
,
mf
,
geo
):
if
not
mf
and
(
len
(
point_estimates
)
!=
0
or
len
(
constants
)
!=
0
):
return
dom
=
ift
.
RGSpace
((
12
,),
(
2.12
))
...
...
@@ -51,11 +52,19 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
'n_samples'
:
nsamps
,
'mean'
:
mean0
,
'hamiltonian'
:
h
}
if
geo
:
args
[
'minimizer_samp'
]
=
ift
.
NewtonCG
(
ic
)
if
isinstance
(
mean0
,
ift
.
MultiField
)
and
set
(
point_estimates
)
==
set
(
mean0
.
keys
()):
with
assert_raises
(
RuntimeError
):
ift
.
MetricGaussianKL
.
make
(
**
args
)
if
geo
:
ift
.
GeoMetricKL
(
**
args
)
else
:
ift
.
MetricGaussianKL
(
**
args
)
return
kl
=
ift
.
MetricGaussianKL
.
make
(
**
args
)
if
geo
:
kl
=
ift
.
GeoMetricKL
(
**
args
)
else
:
kl
=
ift
.
MetricGaussianKL
(
**
args
)
myassert
(
len
(
ic
.
history
)
>
0
)
myassert
(
len
(
ic
.
history
)
==
len
(
ic
.
history
.
time_stamps
))
myassert
(
len
(
ic
.
history
)
==
len
(
ic
.
history
.
energy_values
))
...
...
@@ -71,8 +80,10 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
else
:
tmph
=
h
tmpmean
=
mean0
klpure
=
ift
.
MetricGaussianKL
(
tmpmean
,
tmph
,
nsamps
,
mirror_samples
,
None
,
locsamp
,
False
,
True
)
if
geo
and
mirror_samples
:
klpure
=
ift
.
minimization
.
kl_energies
.
_SampledKLEnergy
(
tmpmean
,
tmph
,
2
*
nsamps
,
False
,
None
,
locsamp
,
False
)
else
:
klpure
=
ift
.
minimization
.
kl_energies
.
_SampledKLEnergy
(
tmpmean
,
tmph
,
nsamps
,
mirror_samples
,
None
,
locsamp
,
False
)
# Test number of samples
expected_nsamps
=
2
*
nsamps
if
mirror_samples
else
nsamps
myassert
(
len
(
tuple
(
kl
.
samples
))
==
expected_nsamps
)
...
...
test/test_mpi/test_kl.py