ift / NIFTy

Commit a87a957c, authored Jun 19, 2020 by Martin Reinecke

Merge branch 'rework_kl' into 'NIFTy_7'

Rework kl

See merge request !540

Parents: fb0a4dc3, 74430ca9
Pipeline #76953 passed with stages in 12 minutes and 22 seconds
Changes: 11 · Pipelines: 1
ChangeLog.md
 Changes since NIFTy 6
 =====================
-*None.*
+MetricGaussianKL interface
+--------------------------
+Users should no longer instantiate `MetricGaussianKL` via its constructor.
+Use `MetricGaussianKL.make()` instead.
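A minimal before/after sketch of the migration (the toy model below is our own setup, not part of this commit; any `StandardHamiltonian` and starting position work the same way):

    import numpy as np
    import nifty7 as ift

    # Toy inference problem; only the MetricGaussianKL construction matters.
    dom = ift.RGSpace(8)
    model = ift.ScalingOperator(dom, 1.).exp()
    lh = ift.GaussianEnergy(ift.full(dom, 1.), sampling_dtype=np.float64) @ model
    ham = ift.StandardHamiltonian(
        lh, ic_samp=ift.GradientNormController(iteration_limit=10))
    mean = ift.full(dom, 0.)

    # Old, no longer supported:
    #   kl = ift.MetricGaussianKL(mean, ham, 5, mirror_samples=True)
    # New, via the factory function:
    kl = ift.MetricGaussianKL.make(mean, ham, 5, mirror_samples=True)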
 Changes since NIFTy 5
...
demos/getting_started_3.py
...
@@ -131,7 +131,7 @@ def main():
     # Draw new samples to approximate the KL five times
     for i in range(5):
         # Draw new samples and minimize KL
-        KL = ift.MetricGaussianKL(mean, H, N_samples)
+        KL = ift.MetricGaussianKL.make(mean, H, N_samples)
         KL, convergence = minimizer(KL)
         mean = KL.position
...
@@ -144,7 +144,7 @@ def main():
                           name=filename.format("loop_{:02d}".format(i)))
     # Draw posterior samples
-    KL = ift.MetricGaussianKL(mean, H, N_samples)
+    KL = ift.MetricGaussianKL.make(mean, H, N_samples)
     sc = ift.StatCalculator()
     for sample in KL.samples:
         sc.add(signal(sample + KL.position))
...

demos/getting_started_5_mf.py
...
@@ -131,7 +131,7 @@ def main():
     for i in range(10):
         # Draw new samples and minimize KL
-        KL = ift.MetricGaussianKL(mean, H, N_samples)
+        KL = ift.MetricGaussianKL.make(mean, H, N_samples)
         KL, convergence = minimizer(KL)
         mean = KL.position
...
@@ -157,7 +157,7 @@ def main():
                           name=filename.format("loop_{:02d}".format(i)))
     # Done, draw posterior samples
-    KL = ift.MetricGaussianKL(mean, H, N_samples)
+    KL = ift.MetricGaussianKL.make(mean, H, N_samples)
     sc = ift.StatCalculator()
     scA1 = ift.StatCalculator()
     scA2 = ift.StatCalculator()
...

demos/mgvi_visualized.py
...
@@ -34,6 +34,7 @@ from matplotlib.colors import LogNorm
 import nifty7 as ift


 def main():
     dom = ift.UnstructuredDomain(1)
     scale = 10
...
@@ -90,7 +91,7 @@ def main():
     plt.figure(figsize=[12, 8])
     for ii in range(15):
         if ii % 3 == 0:
-            mgkl = ift.MetricGaussianKL(pos, ham, 40)
+            mgkl = ift.MetricGaussianKL.make(pos, ham, 40)
         plt.cla()
         plt.imshow(z.T, origin='lower', norm=LogNorm(), vmin=1e-3,
...
src/minimization/metric_gaussian_kl.py
...
@@ -24,7 +24,7 @@ from ..multi_field import MultiField
 from ..operators.endomorphic_operator import EndomorphicOperator
 from ..operators.energy_operators import StandardHamiltonian
 from ..probing import approximation2endo
-from ..sugar import makeDomain, makeOp
+from ..sugar import makeOp
 from .energy import Energy
...
@@ -42,6 +42,11 @@ class _KLMetric(EndomorphicOperator):
         return self._KL._metric_sample(from_inverse)


+def _get_lo_hi(comm, n_samples):
+    ntask, rank, _ = utilities.get_MPI_params_from_comm(comm)
+    return utilities.shareRange(n_samples, ntask, rank)
+
+
 class MetricGaussianKL(Energy):
     """Provides the sampled Kullback-Leibler divergence between a distribution
     and a Metric Gaussian.
...
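The new module-level helper `_get_lo_hi` centralizes how the global range of sample indices is split across MPI tasks. A pure-Python sketch of the even-split arithmetic that `utilities.shareRange` is expected to perform (our illustration, not the library code):

    def share_range(n_samples, ntask, rank):
        # Split range(n_samples) into ntask contiguous, nearly equal chunks;
        # the first n_samples % ntask tasks receive one extra sample.
        base, extra = divmod(n_samples, ntask)
        lo = rank*base + min(rank, extra)
        hi = lo + base + (1 if rank < extra else 0)
        return lo, hi

    # 10 samples on 3 tasks -> (0, 4), (4, 7), (7, 10)
    assert [share_range(10, 3, r) for r in range(3)] == [(0, 4), (4, 7), (7, 10)]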
@@ -58,58 +63,89 @@ class MetricGaussianKL(Energy):
     true probability distribution the standard parametrization is assumed.
     The samples of this class can be distributed among MPI tasks.

-    Parameters
-    ----------
-    mean : Field
-        Mean of the Gaussian probability distribution.
-    hamiltonian : StandardHamiltonian
-        Hamiltonian of the approximated probability distribution.
-    n_samples : integer
-        Number of samples used to stochastically estimate the KL.
-    constants : list
-        List of parameter keys that are kept constant during optimization.
-        Default is no constants.
-    point_estimates : list
-        List of parameter keys for which no samples are drawn, but that are
-        (possibly) optimized for, corresponding to point estimates of these.
-        Default is to draw samples for the complete domain.
-    mirror_samples : boolean
-        Whether the negatives of the drawn samples are also used, as they are
-        equally legitimate samples. If true, the number of used samples
-        doubles. Mirroring samples stabilizes the KL estimate as extreme
-        sample variation is counterbalanced. Default is False.
-    napprox : int
-        Number of samples for computing the preconditioner for sampling. No
-        preconditioning is done by default.
-    comm : MPI communicator or None
-        If not None, samples will be distributed as evenly as possible across
-        this communicator. If `mirror_samples` is set, a sample and its
-        mirror image will always reside on the same task.
-    nanisinf : bool
-        If true, NaN energies, which can happen due to overflows in the
-        forward model, are interpreted as inf. Thereby, the code does not
-        crash on these occasions but rather the minimizer is told that the
-        position it has tried is not sensible.
-    _local_samples : None
-        Only a parameter for internal use. Typically not to be set by users.
-
-    Note
-    ----
-    The two lists `constants` and `point_estimates` are independent of each
-    other. It is possible to sample along domains which are kept constant
-    during minimization and vice versa.
+    Notes
+    -----
+    MetricGaussianKL objects should never be created using the constructor,
+    but rather via the factory function :attr:`make`!

     See also
     --------
     `Metric Gaussian Variational Inference`, Jakob Knollmüller,
     Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
     """
-    def __init__(self, mean, hamiltonian, n_samples, constants=[],
-                 point_estimates=[], mirror_samples=False, napprox=0,
-                 comm=None, _local_samples=None, nanisinf=False):
+    def __init__(self, mean, hamiltonian, n_samples, mirror_samples, comm,
+                 local_samples, nanisinf, _callingfrommake=False):
+        if not _callingfrommake:
+            raise NotImplementedError
         super(MetricGaussianKL, self).__init__(mean)
+        self._hamiltonian = hamiltonian
+        self._n_samples = int(n_samples)
+        self._mirror_samples = bool(mirror_samples)
+        self._comm = comm
+        self._local_samples = local_samples
+        self._nanisinf = bool(nanisinf)
+
+        lin = Linearization.make_var(mean)
+        v, g = [], []
+        for s in self._local_samples:
+            tmp = hamiltonian(lin + s)
+            tv = tmp.val.val
+            tg = tmp.gradient
+            if mirror_samples:
+                tmp = hamiltonian(lin - s)
+                tv = tv + tmp.val.val
+                tg = tg + tmp.gradient
+            v.append(tv)
+            g.append(tg)
+        self._val = utilities.allreduce_sum(v, self._comm)[()]/self.n_eff_samples
+        if np.isnan(self._val) and self._nanisinf:
+            self._val = np.inf
+        self._grad = utilities.allreduce_sum(g, self._comm)/self.n_eff_samples
+
+    @staticmethod
+    def make(mean, hamiltonian, n_samples, constants=[], point_estimates=[],
+             mirror_samples=False, napprox=0, comm=None, nanisinf=False):
"""Return instance of :class:`MetricGaussianKL`.
Parameters
----------
mean : Field
Mean of the Gaussian probability distribution.
hamiltonian : StandardHamiltonian
Hamiltonian of the approximated probability distribution.
n_samples : integer
Number of samples used to stochastically estimate the KL.
constants : list
List of parameter keys that are kept constant during optimization.
Default is no constants.
point_estimates : list
List of parameter keys for which no samples are drawn, but that are
(possibly) optimized for, corresponding to point estimates of these.
Default is to draw samples for the complete domain.
mirror_samples : boolean
Whether the negative of the drawn samples are also used,
as they are equally legitimate samples. If true, the number of used
samples doubles. Mirroring samples stabilizes the KL estimate as
extreme sample variation is counterbalanced. Default is False.
napprox : int
Number of samples for computing preconditioner for sampling. No
preconditioning is done by default.
comm : MPI communicator or None
If not None, samples will be distributed as evenly as possible
across this communicator. If `mirror_samples` is set, then a sample and
its mirror image will always reside on the same task.
nanisinf : bool
If true, nan energies which can happen due to overflows in the forward
model are interpreted as inf. Thereby, the code does not crash on
these occaisions but rather the minimizer is told that the position it
has tried is not sensible.
Note
----
The two lists `constants` and `point_estimates` are independent from each
other. It is possible to sample along domains which are kept constant
during minimization and vice versa.
"""
         if not isinstance(hamiltonian, StandardHamiltonian):
             raise TypeError
...
@@ -117,72 +153,39 @@ class MetricGaussianKL(Energy):
             raise ValueError
         if not isinstance(n_samples, int):
             raise TypeError
-        self._mitigate_nans = nanisinf
         if not isinstance(mirror_samples, bool):
             raise TypeError
         if isinstance(mean, MultiField) and set(point_estimates) == set(mean.keys()):
             raise RuntimeError(
                 'Point estimates for whole domain. Use EnergyAdapter instead.')
+        n_samples = int(n_samples)
+        mirror_samples = bool(mirror_samples)
-        self._hamiltonian = hamiltonian
-        if len(constants) > 0:
-            dom = {kk: vv for kk, vv in mean.domain.items()
-                   if kk in constants}
-            dom = makeDomain(dom)
-            cstpos = mean.extract(dom)
-            _, self._hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
-        self._n_samples = int(n_samples)
-        self._comm = comm
-        ntask, rank, _ = utilities.get_MPI_params_from_comm(self._comm)
-        self._lo, self._hi = utilities.shareRange(self._n_samples, ntask, rank)
-        self._mirror_samples = bool(mirror_samples)
-        self._n_eff_samples = self._n_samples
-        if self._mirror_samples:
-            self._n_eff_samples *= 2
-        if _local_samples is None:
-            if len(point_estimates) > 0:
-                dom = {kk: vv for kk, vv in mean.domain.items()
-                       if kk in point_estimates}
-                dom = makeDomain(dom)
-                cstpos = mean.extract(dom)
-                _, hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
-            met = hamiltonian(Linearization.make_var(mean, True)).metric
-            if napprox >= 1:
-                met._approximation = makeOp(approximation2endo(met, napprox))
-            _local_samples = []
-            sseq = random.spawn_sseq(self._n_samples)
-            for i in range(self._lo, self._hi):
-                with random.Context(sseq[i]):
-                    _local_samples.append(met.draw_sample(from_inverse=True))
-            _local_samples = tuple(_local_samples)
-        else:
-            if len(_local_samples) != self._hi - self._lo:
-                raise ValueError("# of samples mismatch")
-        self._local_samples = _local_samples
-        self._lin = Linearization.make_var(mean)
-        v, g = [], []
-        for s in self._local_samples:
-            tmp = self._hamiltonian(self._lin + s)
-            tv = tmp.val.val
-            tg = tmp.gradient
-            if self._mirror_samples:
-                tmp = self._hamiltonian(self._lin - s)
-                tv = tv + tmp.val.val
-                tg = tg + tmp.gradient
-            v.append(tv)
-            g.append(tg)
-        self._val = utilities.allreduce_sum(v, self._comm)[()]/self._n_eff_samples
-        if np.isnan(self._val) and self._mitigate_nans:
-            self._val = np.inf
-        self._grad = utilities.allreduce_sum(g, self._comm)/self._n_eff_samples
+        if isinstance(mean, MultiField):
+            cstpos = mean.extract_by_keys(point_estimates)
+            _, ham_sampling = hamiltonian.simplify_for_constant_input(cstpos)
+        else:
+            ham_sampling = hamiltonian
+        met = ham_sampling(Linearization.make_var(mean, True)).metric
+        if napprox >= 1:
+            met._approximation = makeOp(approximation2endo(met, napprox))
+        local_samples = []
+        sseq = random.spawn_sseq(n_samples)
+        for i in range(*_get_lo_hi(comm, n_samples)):
+            with random.Context(sseq[i]):
+                local_samples.append(met.draw_sample(from_inverse=True))
+        local_samples = tuple(local_samples)
+        if isinstance(mean, MultiField):
+            _, hamiltonian = hamiltonian.simplify_for_constant_input(
+                mean.extract_by_keys(constants))
+        return MetricGaussianKL(
+            mean, hamiltonian, n_samples, mirror_samples, comm, local_samples,
+            nanisinf, _callingfrommake=True)

     def at(self, position):
         return MetricGaussianKL(
-            position, self._hamiltonian, self._n_samples,
-            mirror_samples=self._mirror_samples, comm=self._comm,
-            _local_samples=self._local_samples, nanisinf=self._mitigate_nans)
+            position, self._hamiltonian, self._n_samples,
+            self._mirror_samples, self._comm, self._local_samples,
+            self._nanisinf, True)

     @property
     def value(self):
...
@@ -193,14 +196,20 @@ class MetricGaussianKL(Energy):
         return self._grad

     def apply_metric(self, x):
-        lin = self._lin.with_want_metric()
+        lin = Linearization.make_var(self.position, want_metric=True)
         res = []
         for s in self._local_samples:
             tmp = self._hamiltonian(lin + s).metric(x)
             if self._mirror_samples:
                 tmp = tmp + self._hamiltonian(lin - s).metric(x)
             res.append(tmp)
-        return utilities.allreduce_sum(res, self._comm)/self._n_eff_samples
+        return utilities.allreduce_sum(res, self._comm)/self.n_eff_samples
+
+    @property
+    def n_eff_samples(self):
+        if self._mirror_samples:
+            return 2*self._n_samples
+        return self._n_samples

     @property
     def metric(self):
...
@@ -216,9 +225,10 @@ class MetricGaussianKL(Energy):
                 yield -s
         else:
             rank_lo_hi = [utilities.shareRange(self._n_samples, ntask, i)
                           for i in range(ntask)]
+            lo, _ = _get_lo_hi(self._comm, self._n_samples)
             for itask, (l, h) in enumerate(rank_lo_hi):
                 for i in range(l, h):
-                    data = self._local_samples[i-self._lo] if rank == itask else None
+                    data = self._local_samples[i-lo] if rank == itask else None
                     s = self._comm.bcast(data, root=itask)
                     yield s
                     if self._mirror_samples:
...
@@ -231,7 +241,7 @@ class MetricGaussianKL(Energy):
                 ' not take point_estimates into account. Make sure that this '
                 'is your intended use.')
             logger.warning(s)
-        lin = self._lin.with_want_metric()
+        lin = Linearization.make_var(self.position, True)
         samp = []
         sseq = random.spawn_sseq(self._n_samples)
         for i, v in enumerate(self._local_samples):
...
@@ -240,4 +250,4 @@ class MetricGaussianKL(Energy):
             if self._mirror_samples:
                 tmp = tmp + self._hamiltonian(lin - v).metric.draw_sample(from_inverse=False)
             samp.append(tmp)
-        return utilities.allreduce_sum(samp, self._comm)/self._n_eff_samples
+        return utilities.allreduce_sum(samp, self._comm)/self.n_eff_samples
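Putting the reworked class together: `make()` draws the samples once, mirroring doubles the effective sample count, and the minimizer reuses the stored samples via `at()`. A minimal end-to-end sketch (the toy model is ours; `comm=None`, the default, keeps everything on a single task):

    import numpy as np
    import nifty7 as ift

    dom = ift.RGSpace(16)
    model = ift.ScalingOperator(dom, 1.).exp()
    lh = ift.GaussianEnergy(ift.full(dom, 1.), sampling_dtype=np.float64) @ model
    ham = ift.StandardHamiltonian(
        lh, ic_samp=ift.GradientNormController(iteration_limit=5))
    mean = ift.full(dom, 0.)

    # mirror_samples reuses each drawn sample with flipped sign, so the
    # effective sample count doubles at no extra sampling cost.
    kl = ift.MetricGaussianKL.make(mean, ham, 4, mirror_samples=True)
    assert kl.n_eff_samples == 8

    # The minimizer moves the mean while the samples stay fixed.
    minimizer = ift.NewtonCG(ift.GradientNormController(iteration_limit=3))
    kl, _ = minimizer(kl)
    mean = kl.position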
src/multi_field.py
...
@@ -248,6 +248,10 @@ class MultiField(Operator):
         return MultiField(subset,
                           tuple(self[key] for key in subset.keys()))

+    def extract_by_keys(self, keys):
+        dom = MultiDomain.make({kk: vv for kk, vv in self.domain.items()
+                                if kk in keys})
+        return self.extract(dom)
+
     def extract_part(self, subset):
         if subset is self._domain:
             return self
...
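`MultiField.extract_by_keys` is a small convenience wrapper around `extract` that builds the restricted `MultiDomain` from plain keys; `MetricGaussianKL.make` uses it for `constants` and `point_estimates`. A usage sketch with toy domains of our choosing:

    import nifty7 as ift

    dom = ift.makeDomain({'a': ift.RGSpace(4), 'b': ift.RGSpace(2)})
    mf = ift.full(dom, 0.)

    # Restrict the MultiField to the listed keys; keys not present in the
    # MultiField are silently ignored by the `if kk in keys` filter.
    sub = mf.extract_by_keys(['a'])
    print(list(sub.domain.keys()))  # ['a']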
src/operators/operator.py
...
@@ -275,22 +275,25 @@ class Operator(metaclass=NiftyMeta):
         from .simplify_for_const import ConstantEnergyOperator, ConstantOperator
         if c_inp is None:
             return None, self
+        dom = c_inp.domain
+        if isinstance(dom, MultiDomain) and len(dom) == 0:
+            return None, self
         # Convention: If c_inp is MultiField, it needs to be defined on a
         # subdomain of self._domain
         if isinstance(self.domain, MultiDomain):
-            assert isinstance(c_inp.domain, MultiDomain)
+            assert isinstance(dom, MultiDomain)
             if set(c_inp.keys()) > set(self.domain.keys()):
                 raise ValueError
-        if c_inp.domain is self.domain:
-            op = ConstantOperator(self.domain, self(c_inp))
+        if dom is self.domain:
+            if isinstance(self, EnergyOperator):
+                op = ConstantEnergyOperator(self.domain, self(c_inp))
+            else:
+                op = ConstantOperator(self.domain, self(c_inp))
             return op(c_inp), op
-        if not isinstance(c_inp.domain, MultiDomain):
+        if not isinstance(dom, MultiDomain):
             raise RuntimeError
         return self._simplify_for_constant_input_nontrivial(c_inp)
...
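The full-domain branch above collapses an operator applied to an entirely constant input into a constant. A short sketch of that behaviour (toy operator of our choosing; the partial MultiField case is delegated to `_simplify_for_constant_input_nontrivial` instead):

    import nifty7 as ift

    dom = ift.RGSpace(4)
    op = ift.ScalingOperator(dom, 2.).exp()
    c_inp = ift.full(dom, 0.)

    # c_inp covers op's whole domain, so op collapses to a constant:
    # `out` equals op(c_inp) and `simplified` returns that value for
    # any input.
    out, simplified = op.simplify_for_constant_input(c_inp)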
src/sugar.py
...
@@ -520,7 +520,7 @@ def calculate_position(operator, output):
     minimizer = NewtonCG(GradientNormController(iteration_limit=10,
                                                 name='findpos'))
     for ii in range(3):
         logger.info(f'Start iteration {ii+1}/3')
-        kl = MetricGaussianKL(pos, H, 3, mirror_samples=True)
+        kl = MetricGaussianKL.make(pos, H, 3, mirror_samples=True)
         kl, _ = minimizer(kl)
         pos = kl.position
     return pos
test/test_kl.py
...
@@ -52,9 +52,9 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
             'hamiltonian': h}
     if isinstance(mean0, ift.MultiField) and set(point_estimates) == set(mean0.keys()):
         with assert_raises(RuntimeError):
-            ift.MetricGaussianKL(**args)
+            ift.MetricGaussianKL.make(**args)
         return
-    kl = ift.MetricGaussianKL(**args)
+    kl = ift.MetricGaussianKL.make(**args)
     assert_(len(ic.history) > 0)
     assert_(len(ic.history) == len(ic.history.time_stamps))
     assert_(len(ic.history) == len(ic.history.energy_values))
...
@@ -64,13 +64,11 @@ def test_kl(constants, point_estimates, mirror_samples, mf):
     assert_(len(ic.history) == len(ic.history.energy_values))

     locsamp = kl._local_samples
-    klpure = ift.MetricGaussianKL(mean0, h, nsamps,
-                                  mirror_samples=mirror_samples,
-                                  constants=constants,
-                                  point_estimates=point_estimates,
-                                  _local_samples=locsamp)
+    if isinstance(mean0, ift.MultiField):
+        _, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants))
+    else:
+        tmph = h
+    klpure = ift.MetricGaussianKL(mean0, tmph, nsamps, mirror_samples, None,
+                                  locsamp, False, True)

     # Test number of samples
     expected_nsamps = 2*nsamps if mirror_samples else nsamps
...
test/test_mpi/test_kl.py
...
@@ -60,19 +60,27 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf):
             'hamiltonian': h}
     if isinstance(mean0, ift.MultiField) and set(point_estimates) == set(mean0.keys()):
         with assert_raises(RuntimeError):
-            ift.MetricGaussianKL(**args, comm=comm)
+            ift.MetricGaussianKL.make(**args, comm=comm)
         return
     if mode == 0:
-        kl0 = ift.MetricGaussianKL(**args, comm=comm)
+        kl0 = ift.MetricGaussianKL.make(**args, comm=comm)
         locsamp = kl0._local_samples
-        kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp)
+        if isinstance(mean0, ift.MultiField):
+            _, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants))
+        else:
+            tmph = h
+        kl1 = ift.MetricGaussianKL(mean0, tmph, 2, mirror_samples, comm,
+                                   locsamp, False, True)
     elif mode == 1:
-        kl0 = ift.MetricGaussianKL(**args)
+        kl0 = ift.MetricGaussianKL.make(**args)
         samples = kl0._local_samples
         ii = len(samples)//2
         slc = slice(None, ii) if rank == 0 else slice(ii, None)
         locsamp = samples[slc]
-        kl1 = ift.MetricGaussianKL(**args, comm=comm, _local_samples=locsamp)
+        if isinstance(mean0, ift.MultiField):
+            _, tmph = h.simplify_for_constant_input(mean0.extract_by_keys(constants))
+        else:
+            tmph = h
+        kl1 = ift.MetricGaussianKL(mean0, tmph, 2, mirror_samples, comm,
+                                   locsamp, False, True)

     # Test number of samples
     expected_nsamps = 2*nsamps if mirror_samples else nsamps
...
test/test_sugar.py
...
@@ -16,11 +16,15 @@
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

 import numpy as np
+import pytest
 from numpy.testing import assert_equal

 import nifty7 as ift

 from .common import setup_function, teardown_function

+pmp = pytest.mark.parametrize
+
+
 def test_get_signal_variance():
     space = ift.RGSpace(3)
...
@@ -45,15 +49,13 @@ def test_exec_time():
lh
=
ift
.
GaussianEnergy
(
domain
=
op
.
target
,
sampling_dtype
=
np
.
float64
)
@
op1
ic
=
ift
.
GradientNormController
(
iteration_limit
=
2
)
ham
=
ift
.
StandardHamiltonian
(
lh
,
ic_samp
=
ic
)
kl
=
ift
.
MetricGaussianKL
(
ift
.
full
(
ham
.
domain
,
0.
),
ham
,
1
)
kl
=
ift
.
MetricGaussianKL
.
make
(
ift
.
full
(
ham
.
domain
,
0.
),
ham
,
1
)
ops
=
[
op
,
op1
,
lh
,
ham
,
kl
]
for
oo
in
ops
:
for
wm
in
[
True
,
False
]:
ift
.
exec_time
(
oo
,
wm
)
import
pytest
pmp
=
pytest
.
mark
.
parametrize
@
pmp
(
'mf'
,
[
False
,
True
])
@
pmp
(
'cplx'
,
[
False
,
True
])
def
test_calc_pos
(
mf
,
cplx
):
...
...