a87a957c
Commit
a87a957c
authored
Jun 19, 2020
by
Martin Reinecke
Browse files
Merge branch 'rework_kl' into 'NIFTy_7'
Rework kl See merge request
!540
parents
fb0a4dc3
74430ca9
Pipeline
#76953
passed with stages
in 12 minutes and 22 seconds
Changes
11
Pipelines
1
ChangeLog.md

 Changes since NIFTy 6
 =====================

-*None.*
+MetricGaussianKL interface
+--------------------------
+Users do not instanciate `MetricGaussianKL` by its constructor anymore. Rather
+`MetricGaussianKL.make()` shall be used.

 Changes since NIFTy 5
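For readers who only care about the user-facing change, the demo updates below boil down to the following one-line migration. This is a fragment rather than a runnable script: `mean`, `H`, `N_samples` and `minimizer` stand for whatever the surrounding script has already set up, exactly as in the demos touched by this commit.

# Before this commit: the energy was instantiated via the constructor.
# KL = ift.MetricGaussianKL(mean, H, N_samples)

# After this commit: the factory function is used instead; everything
# downstream (minimization, reading KL.position) is unchanged.
KL = ift.MetricGaussianKL.make(mean, H, N_samples)
KL, convergence = minimizer(KL)
mean = KL.position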
demos/getting_started_3.py

@@ -131,7 +131,7 @@ def main():
     # Draw new samples to approximate the KL five times
     for i in range(5):
         # Draw new samples and minimize KL
-        KL = ift.MetricGaussianKL(mean, H, N_samples)
+        KL = ift.MetricGaussianKL.make(mean, H, N_samples)
         KL, convergence = minimizer(KL)
         mean = KL.position

@@ -144,7 +144,7 @@ def main():
             name=filename.format("loop_{:02d}".format(i)))

     # Draw posterior samples
-    KL = ift.MetricGaussianKL(mean, H, N_samples)
+    KL = ift.MetricGaussianKL.make(mean, H, N_samples)
     sc = ift.StatCalculator()
     for sample in KL.samples:
         sc.add(signal(sample + KL.position))
demos/getting_started_5_mf.py

@@ -131,7 +131,7 @@ def main():
     for i in range(10):
         # Draw new samples and minimize KL
-        KL = ift.MetricGaussianKL(mean, H, N_samples)
+        KL = ift.MetricGaussianKL.make(mean, H, N_samples)
         KL, convergence = minimizer(KL)
         mean = KL.position

@@ -157,7 +157,7 @@ def main():
             name=filename.format("loop_{:02d}".format(i)))

     # Done, draw posterior samples
-    KL = ift.MetricGaussianKL(mean, H, N_samples)
+    KL = ift.MetricGaussianKL.make(mean, H, N_samples)
     sc = ift.StatCalculator()
     scA1 = ift.StatCalculator()
     scA2 = ift.StatCalculator()
demos/mgvi_visualized.py

@@ -34,6 +34,7 @@ from matplotlib.colors import LogNorm
 import nifty7 as ift

+
 def main():
     dom = ift.UnstructuredDomain(1)
     scale = 10

@@ -90,7 +91,7 @@ def main():
     plt.figure(figsize=[12, 8])
     for ii in range(15):
         if ii % 3 == 0:
-            mgkl = ift.MetricGaussianKL(pos, ham, 40)
+            mgkl = ift.MetricGaussianKL.make(pos, ham, 40)

         plt.cla()
         plt.imshow(z.T, origin='lower', norm=LogNorm(), vmin=1e-3,
src/minimization/metric_gaussian_kl.py

@@ -24,7 +24,7 @@ from ..multi_field import MultiField
 from ..operators.endomorphic_operator import EndomorphicOperator
 from ..operators.energy_operators import StandardHamiltonian
 from ..probing import approximation2endo
-from ..sugar import makeDomain, makeOp
+from ..sugar import makeOp
 from .energy import Energy

@@ -42,6 +42,11 @@ class _KLMetric(EndomorphicOperator):
         return self._KL._metric_sample(from_inverse)


+def _get_lo_hi(comm, n_samples):
+    ntask, rank, _ = utilities.get_MPI_params_from_comm(comm)
+    return utilities.shareRange(n_samples, ntask, rank)
+
+
 class MetricGaussianKL(Energy):
     """Provides the sampled Kullback-Leibler divergence between a distribution
     and a Metric Gaussian.

@@ -58,6 +63,50 @@ class MetricGaussianKL(Energy):
     true probability distribution the standard parametrization is assumed.
     The samples of this class can be distributed among MPI tasks.

+    Notes
+    -----
+    DomainTuples should never be created using the constructor, but rather
+    via the factory function :attr:`make`!
+
+    See also
+    --------
+    `Metric Gaussian Variational Inference`, Jakob Knollmüller,
+    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
+    """
+
+    def __init__(self, mean, hamiltonian, n_samples, mirror_samples, comm,
+                 local_samples, nanisinf, _callingfrommake=False):
+        if not _callingfrommake:
+            raise NotImplementedError
+        super(MetricGaussianKL, self).__init__(mean)
+        self._hamiltonian = hamiltonian
+        self._n_samples = int(n_samples)
+        self._mirror_samples = bool(mirror_samples)
+        self._comm = comm
+        self._local_samples = local_samples
+        self._nanisinf = bool(nanisinf)
+
+        lin = Linearization.make_var(mean)
+        v, g = [], []
+        for s in self._local_samples:
+            tmp = hamiltonian(lin+s)
+            tv = tmp.val.val
+            tg = tmp.gradient
+            if mirror_samples:
+                tmp = hamiltonian(lin-s)
+                tv = tv + tmp.val.val
+                tg = tg + tmp.gradient
+            v.append(tv)
+            g.append(tg)
+        self._val = utilities.allreduce_sum(v, self._comm)[()]/self.n_eff_samples
+        if np.isnan(self._val) and self._nanisinf:
+            self._val = np.inf
+        self._grad = utilities.allreduce_sum(g, self._comm)/self.n_eff_samples
+
+    @staticmethod
+    def make(mean, hamiltonian, n_samples, constants=[], point_estimates=[],
+             mirror_samples=False, napprox=0, comm=None, nanisinf=False):
+        """Return instance of :class:`MetricGaussianKL`.
+
     Parameters
     ----------
     mean : Field

@@ -90,99 +139,53 @@ class MetricGaussianKL(Energy):
     model are interpreted as inf. Thereby, the code does not crash on
     these occaisions but rather the minimizer is told that the position it
     has tried is not sensible.
-    _local_samples : None
-        Only a parameter for internal uses. Typically not to be set by users.

     Note
     ----
     The two lists `constants` and `point_estimates` are independent from each
     other. It is possible to sample along domains which are kept constant
     during minimization and vice versa.
-
-    See also
-    --------
-    `Metric Gaussian Variational Inference`, Jakob Knollmüller,
-    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
     """
-
-    def __init__(self, mean, hamiltonian, n_samples, constants=[],
-                 point_estimates=[], mirror_samples=False, napprox=0,
-                 comm=None, _local_samples=None, nanisinf=False):
-        super(MetricGaussianKL, self).__init__(mean)
-
         if not isinstance(hamiltonian, StandardHamiltonian):
             raise TypeError
         if hamiltonian.domain is not mean.domain:
             raise ValueError
         if not isinstance(n_samples, int):
             raise TypeError
-        self._mitigate_nans = nanisinf
         if not isinstance(mirror_samples, bool):
             raise TypeError
         if isinstance(mean, MultiField) and set(point_estimates) == set(mean.keys()):
             raise RuntimeError(
                 'Point estimates for whole domain. Use EnergyAdapter instead.')
-
-        self._hamiltonian = hamiltonian
-        if len(constants) > 0:
-            dom = {kk: vv for kk, vv in mean.domain.items()
-                   if kk in constants}
-            dom = makeDomain(dom)
-            cstpos = mean.extract(dom)
-            _, self._hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
-
-        self._n_samples = int(n_samples)
-        self._comm = comm
-        ntask, rank, _ = utilities.get_MPI_params_from_comm(self._comm)
-        self._lo, self._hi = utilities.shareRange(self._n_samples, ntask, rank)
-
-        self._mirror_samples = bool(mirror_samples)
-        self._n_eff_samples = self._n_samples
-        if self._mirror_samples:
-            self._n_eff_samples *= 2
-
-        if _local_samples is None:
-            if len(point_estimates) > 0:
-                dom = {kk: vv for kk, vv in mean.domain.items()
-                       if kk in point_estimates}
-                dom = makeDomain(dom)
-                cstpos = mean.extract(dom)
-                _, hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
-            met = hamiltonian(Linearization.make_var(mean, True)).metric
-            if napprox >= 1:
-                met._approximation = makeOp(approximation2endo(met, napprox))
-            _local_samples = []
-            sseq = random.spawn_sseq(self._n_samples)
-            for i in range(self._lo, self._hi):
-                with random.Context(sseq[i]):
-                    _local_samples.append(met.draw_sample(from_inverse=True))
-            _local_samples = tuple(_local_samples)
-        else:
-            if len(_local_samples) != self._hi - self._lo:
-                raise ValueError("# of samples mismatch")
-        self._local_samples = _local_samples
-        self._lin = Linearization.make_var(mean)
-        v, g = [], []
-        for s in self._local_samples:
-            tmp = self._hamiltonian(self._lin+s)
-            tv = tmp.val.val
-            tg = tmp.gradient
-            if self._mirror_samples:
-                tmp = self._hamiltonian(self._lin-s)
-                tv = tv + tmp.val.val
-                tg = tg + tmp.gradient
-            v.append(tv)
-            g.append(tg)
-        self._val = utilities.allreduce_sum(v, self._comm)[()]/self._n_eff_samples
-        if np.isnan(self._val) and self._mitigate_nans:
-            self._val = np.inf
-        self._grad = utilities.allreduce_sum(g, self._comm)/self._n_eff_samples
+        n_samples = int(n_samples)
+        mirror_samples = bool(mirror_samples)
+
+        if isinstance(mean, MultiField):
+            cstpos = mean.extract_by_keys(point_estimates)
+            _, ham_sampling = hamiltonian.simplify_for_constant_input(cstpos)
+        else:
+            ham_sampling = hamiltonian
+        met = ham_sampling(Linearization.make_var(mean, True)).metric
+        if napprox >= 1:
+            met._approximation = makeOp(approximation2endo(met, napprox))
+        local_samples = []
+        sseq = random.spawn_sseq(n_samples)
+        for i in range(*_get_lo_hi(comm, n_samples)):
+            with random.Context(sseq[i]):
+                local_samples.append(met.draw_sample(from_inverse=True))
+        local_samples = tuple(local_samples)
+
+        if isinstance(mean, MultiField):
+            _, hamiltonian = hamiltonian.simplify_for_constant_input(
+                mean.extract_by_keys(constants))
+        return MetricGaussianKL(
+            mean, hamiltonian, n_samples, mirror_samples, comm, local_samples,
+            nanisinf, _callingfrommake=True)

     def at(self, position):
         return MetricGaussianKL(
-            position, self._hamiltonian, self._n_samples,
-            mirror_samples=self._mirror_samples, comm=self._comm,
-            _local_samples=self._local_samples, nanisinf=self._mitigate_nans)
+            position, self._hamiltonian, self._n_samples, self._mirror_samples,
+            self._comm, self._local_samples, self._nanisinf, True)

     @property
     def value(self):

@@ -193,14 +196,20 @@ class MetricGaussianKL(Energy):
         return self._grad

     def apply_metric(self, x):
-        lin = self._lin.with_want_metric()
+        lin = Linearization.make_var(self.position, want_metric=True)
         res = []
         for s in self._local_samples:
             tmp = self._hamiltonian(lin+s).metric(x)
             if self._mirror_samples:
                 tmp = tmp + self._hamiltonian(lin-s).metric(x)
             res.append(tmp)
-        return utilities.allreduce_sum(res, self._comm)/self._n_eff_samples
+        return utilities.allreduce_sum(res, self._comm)/self.n_eff_samples
+
+    @property
+    def n_eff_samples(self):
+        if self._mirror_samples:
+            return 2*self._n_samples
+        return self._n_samples

     @property
     def metric(self):

@@ -216,9 +225,10 @@ class MetricGaussianKL(Energy):
                     yield -s
         else:
             rank_lo_hi = [utilities.shareRange(self._n_samples, ntask, i) for i in range(ntask)]
+            lo, _ = _get_lo_hi(self._comm, self._n_samples)
             for itask, (l, h) in enumerate(rank_lo_hi):
                 for i in range(l, h):
-                    data = self._local_samples[i-self._lo] if rank == itask else None
+                    data = self._local_samples[i-lo] if rank == itask else None
                     s = self._comm.bcast(data, root=itask)
                     yield s
                     if self._mirror_samples:

@@ -231,7 +241,7 @@ class MetricGaussianKL(Energy):
                 ' not take point_estimates into accout. Make sure that this '
                 'is your intended use.')
            logger.warning(s)
-        lin = self._lin.with_want_metric()
+        lin = Linearization.make_var(self.position, True)
         samp = []
         sseq = random.spawn_sseq(self._n_samples)
         for i, v in enumerate(self._local_samples):

@@ -240,4 +250,4 @@ class MetricGaussianKL(Energy):
             if self._mirror_samples:
                 tmp = tmp + self._hamiltonian(lin-v).metric.draw_sample(from_inverse=False)
             samp.append(tmp)
-        return utilities.allreduce_sum(samp, self._comm)/self._n_eff_samples
+        return utilities.allreduce_sum(samp, self._comm)/self.n_eff_samples
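The net effect of the rework above is a classic guarded-constructor/factory pattern: `MetricGaussianKL.make()` performs all validation, the optional simplification for constant inputs and the drawing of metric samples (split across MPI tasks via `_get_lo_hi`/`utilities.shareRange`), and only then builds the energy through a constructor that refuses to be called directly. The following is a minimal, self-contained sketch of that pattern only; it uses plain Python with made-up names and is not part of NIFTy.

class SampledEnergy:
    """Toy illustration of the guarded-constructor/factory pattern used above."""

    def __init__(self, position, samples, _callingfrommake=False):
        # Direct construction is blocked, in the spirit of MetricGaussianKL.__init__.
        if not _callingfrommake:
            raise NotImplementedError("use SampledEnergy.make()")
        self._position = position
        self._samples = tuple(samples)

    @staticmethod
    def make(position, n_samples):
        # The factory validates inputs and prepares the expensive state ...
        if not isinstance(n_samples, int) or n_samples < 1:
            raise TypeError("n_samples must be a positive int")
        samples = [position + i for i in range(n_samples)]  # stand-in for metric samples
        # ... and only then calls the constructor with the escape hatch set.
        return SampledEnergy(position, samples, _callingfrommake=True)


energy = SampledEnergy.make(0.0, 3)   # works
try:
    SampledEnergy(0.0, [])            # raises NotImplementedError
except NotImplementedError:
    pass

In the real code the factory additionally returns a ready-to-minimize energy, and `at()` reuses the already drawn samples by passing `_callingfrommake=True` itself.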
src/multi_field.py

@@ -248,6 +248,10 @@ class MultiField(Operator):
         return MultiField(subset,
                           tuple(self[key] for key in subset.keys()))

+    def extract_by_keys(self, keys):
+        dom = MultiDomain.make({kk: vv for kk, vv in self.domain.items()
+                                if kk in keys})
+        return self.extract(dom)
+
     def extract_part(self, subset):
         if subset is self._domain:
             return self
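The new `MultiField.extract_by_keys` helper added above is what `MetricGaussianKL.make` uses to restrict the position to the `constants` and `point_estimates` keys. A hedged sketch of the same idea on a plain dict (deliberately not using the real MultiField/MultiDomain types) is:

def extract_by_keys(field, keys):
    # Plain-dict stand-in for MultiField.extract_by_keys: keep only the
    # entries whose key is in `keys`, preserving the original values.
    return {kk: vv for kk, vv in field.items() if kk in keys}


position = {'signal': 1.5, 'noise': 0.1, 'calibration': 0.9}
print(extract_by_keys(position, ['noise', 'calibration']))
# {'noise': 0.1, 'calibration': 0.9}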
src/operators/operator.py

@@ -275,22 +275,25 @@ class Operator(metaclass=NiftyMeta):
         from .simplify_for_const import ConstantEnergyOperator, ConstantOperator
         if c_inp is None:
             return None, self
+        dom = c_inp.domain
+        if isinstance(dom, MultiDomain) and len(dom) == 0:
+            return None, self

         # Convention: If c_inp is MultiField, it needs to be defined on a
         # subdomain of self._domain
         if isinstance(self.domain, MultiDomain):
-            assert isinstance(c_inp.domain, MultiDomain)
+            assert isinstance(dom, MultiDomain)
             if set(c_inp.keys()) > set(self.domain.keys()):
                 raise ValueError
-            if c_inp.domain is self.domain:
+            if dom is self.domain:
                 if isinstance(self, EnergyOperator):
                     op = ConstantEnergyOperator(self.domain, self(c_inp))
                 else:
                     op = ConstantOperator(self.domain, self(c_inp))
                 op = ConstantOperator(self.domain, self(c_inp))
                 return op(c_inp), op
-            if not isinstance(c_inp.domain, MultiDomain):
+            if not isinstance(dom, MultiDomain):
                 raise RuntimeError
             return self._simplify_for_constant_input_nontrivial(c_inp)
src/sugar.py

@@ -520,7 +520,7 @@ def calculate_position(operator, output):
     minimizer = NewtonCG(GradientNormController(iteration_limit=10, name='findpos'))
     for ii in range(3):
         logger.info(f'Start iteration {ii+1}/3')
-        kl = MetricGaussianKL(pos, H, 3, mirror_samples=True)
+        kl = MetricGaussianKL.make(pos, H, 3, mirror_samples=True)
         kl, _ = minimizer(kl)
         pos = kl.position
     return pos
test/test_kl.py

@@ -52,9 +52,9 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
             'hamiltonian': h}
     if isinstance(mean0, ift.MultiField) and set(point_estimates) == set(mean0.keys()):
         with assert_raises(RuntimeError):
-            ift.MetricGaussianKL(**args)
+            ift.MetricGaussianKL.make(**args)
         return
-    kl = ift.MetricGaussianKL(**args)
+    kl = ift.MetricGaussianKL.make(**args)
     assert_(len(ic.history) > 0)
     assert_(len(ic.history) == len(ic.history.time_stamps))
     assert_(len(ic.history) == len(ic.history.energy_values))

@@ -64,13 +64,11 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
     assert_(len(ic.history) == len(ic.history.energy_values))

     locsamp = kl._local_samples
-    klpure = ift.MetricGaussianKL(mean0,
-                                  h,
-                                  nsamps,
-                                  mirror_samples=mirror_samples,
-                                  constants=constants,
-                                  point_estimates=point_estimates,
-                                  _local_samples=locsamp)
+    if isinstance(mean0, ift.MultiField):
+        _, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants))
+    else:
+        tmph = h
+    klpure = ift.MetricGaussianKL(mean0, tmph, nsamps, mirror_samples, None, locsamp, False, True)

     # Test number of samples
     expected_nsamps = 2*nsamps if mirror_samples else nsamps
test/test_mpi/test_kl.py

@@ -60,19 +60,27 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf):
             'hamiltonian': h}
     if isinstance(mean0, ift.MultiField) and set(point_estimates) == set(mean0.keys()):
         with assert_raises(RuntimeError):
-            ift.MetricGaussianKL(**args, comm=comm)
+            ift.MetricGaussianKL.make(**args, comm=comm)
         return

     if mode == 0:
-        kl0 = ift.MetricGaussianKL(**args, comm=comm)
+        kl0 = ift.MetricGaussianKL.make(**args, comm=comm)
         locsamp = kl0._local_samples
-        kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp)
+        if isinstance(mean0, ift.MultiField):
+            _, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants))
+        else:
+            tmph = h
+        kl1 = ift.MetricGaussianKL(mean0, tmph, 2, mirror_samples, comm, locsamp, False, True)
     elif mode == 1:
-        kl0 = ift.MetricGaussianKL(**args)
+        kl0 = ift.MetricGaussianKL.make(**args)
         samples = kl0._local_samples
         ii = len(samples)//2
         slc = slice(None, ii) if rank == 0 else slice(ii, None)
         locsamp = samples[slc]
-        kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp)
+        if isinstance(mean0, ift.MultiField):
+            _, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants))
+        else:
+            tmph = h
+        kl1 = ift.MetricGaussianKL(mean0, tmph, 2, mirror_samples, comm, locsamp, False, True)

     # Test number of samples
     expected_nsamps = 2*nsamps if mirror_samples else nsamps
test/test_sugar.py

@@ -16,11 +16,15 @@
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

 import numpy as np
+import pytest
 from numpy.testing import assert_equal

 import nifty7 as ift

 from .common import setup_function, teardown_function

+pmp = pytest.mark.parametrize
+
+
 def test_get_signal_variance():
     space = ift.RGSpace(3)

@@ -45,15 +49,13 @@ def test_exec_time():
     lh = ift.GaussianEnergy(domain=op.target, sampling_dtype=np.float64) @ op1
     ic = ift.GradientNormController(iteration_limit=2)