Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
193a276f
Commit
193a276f
authored
Mar 26, 2020
by
Martin Reinecke
Browse files
Merge branch 'KL_sample_generator' into 'NIFTy_6'
Use a generator for MetricGaussianKL.samples See merge request
!432
parents
c01d10cc
8e377ee6
Pipeline
#71475
passed with stages
in 27 minutes and 46 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/minimization/metric_gaussian_kl.py
View file @
193a276f
...
...
@@ -121,7 +121,7 @@ class MetricGaussianKL(Energy):
the presence of this parameter is that metric of the likelihood energy
is just an `Operator` which does not know anything about the dtype of
the fields on which it acts. Default is float64.
_samples : None
_local
_samples : None
Only a parameter for internal uses. Typically not to be set by users.
Note
...
...
@@ -138,7 +138,7 @@ class MetricGaussianKL(Energy):
def
__init__
(
self
,
mean
,
hamiltonian
,
n_samples
,
constants
=
[],
point_estimates
=
[],
mirror_samples
=
False
,
napprox
=
0
,
comm
=
None
,
_samples
=
None
,
napprox
=
0
,
comm
=
None
,
_local
_samples
=
None
,
lh_sampling_dtype
=
np
.
float64
):
super
(
MetricGaussianKL
,
self
).
__init__
(
mean
)
...
...
@@ -170,31 +170,31 @@ class MetricGaussianKL(Energy):
if
self
.
_mirror_samples
:
self
.
_n_eff_samples
*=
2
if
_samples
is
None
:
if
_local
_samples
is
None
:
met
=
hamiltonian
(
Linearization
.
make_partial_var
(
mean
,
self
.
_point_estimates
,
True
)).
metric
if
napprox
>=
1
:
met
.
_approximation
=
makeOp
(
approximation2endo
(
met
,
napprox
))
_samples
=
[]
_local
_samples
=
[]
sseq
=
random
.
spawn_sseq
(
self
.
_n_samples
)
for
i
in
range
(
self
.
_lo
,
self
.
_hi
):
random
.
push_sseq
(
sseq
[
i
])
_samples
.
append
(
met
.
draw_sample
(
from_inverse
=
True
,
dtype
=
lh_sampling_dtype
))
_local
_samples
.
append
(
met
.
draw_sample
(
from_inverse
=
True
,
dtype
=
lh_sampling_dtype
))
random
.
pop_sseq
()
_samples
=
tuple
(
_samples
)
_local
_samples
=
tuple
(
_
local_
samples
)
else
:
if
len
(
_samples
)
!=
self
.
_hi
-
self
.
_lo
:
if
len
(
_
local_
samples
)
!=
self
.
_hi
-
self
.
_lo
:
raise
ValueError
(
"# of samples mismatch"
)
self
.
_samples
=
_samples
self
.
_
local_
samples
=
_local
_samples
self
.
_lin
=
Linearization
.
make_partial_var
(
mean
,
self
.
_constants
)
v
,
g
=
None
,
None
if
len
(
self
.
_samples
)
==
0
:
# hack if there are too many MPI tasks
if
len
(
self
.
_
local_
samples
)
==
0
:
# hack if there are too many MPI tasks
tmp
=
self
.
_hamiltonian
(
self
.
_lin
)
v
=
0.
*
tmp
.
val
.
val
g
=
0.
*
tmp
.
gradient
else
:
for
s
in
self
.
_samples
:
for
s
in
self
.
_
local_
samples
:
tmp
=
self
.
_hamiltonian
(
self
.
_lin
+
s
)
if
self
.
_mirror_samples
:
tmp
=
tmp
+
self
.
_hamiltonian
(
self
.
_lin
-
s
)
...
...
@@ -213,7 +213,7 @@ class MetricGaussianKL(Energy):
return
MetricGaussianKL
(
position
,
self
.
_hamiltonian
,
self
.
_n_samples
,
self
.
_constants
,
self
.
_point_estimates
,
self
.
_mirror_samples
,
comm
=
self
.
_comm
,
_samples
=
self
.
_samples
,
lh_sampling_dtype
=
self
.
_sampdt
)
_local
_samples
=
self
.
_
local_
samples
,
lh_sampling_dtype
=
self
.
_sampdt
)
@
property
def
value
(
self
):
...
...
@@ -226,15 +226,15 @@ class MetricGaussianKL(Energy):
def
_get_metric
(
self
):
lin
=
self
.
_lin
.
with_want_metric
()
if
self
.
_metric
is
None
:
if
len
(
self
.
_samples
)
==
0
:
# hack if there are too many MPI tasks
if
len
(
self
.
_
local_
samples
)
==
0
:
# hack if there are too many MPI tasks
self
.
_metric
=
self
.
_hamiltonian
(
lin
).
metric
.
scale
(
0.
)
else
:
mymap
=
map
(
lambda
v
:
self
.
_hamiltonian
(
lin
+
v
).
metric
,
self
.
_samples
)
self
.
_
local_
samples
)
unscaled_metric
=
utilities
.
my_sum
(
mymap
)
if
self
.
_mirror_samples
:
mymap
=
map
(
lambda
v
:
self
.
_hamiltonian
(
lin
-
v
).
metric
,
self
.
_samples
)
self
.
_
local_
samples
)
unscaled_metric
=
unscaled_metric
+
utilities
.
my_sum
(
mymap
)
self
.
_metric
=
unscaled_metric
.
scale
(
1.
/
self
.
_n_eff_samples
)
...
...
@@ -248,14 +248,22 @@ class MetricGaussianKL(Energy):
@
property
def
samples
(
self
):
if
self
.
_comm
is
not
None
:
res
=
self
.
_comm
.
allgather
(
self
.
_samples
)
res
=
tuple
(
item
for
sublist
in
res
for
item
in
sublist
)
if
self
.
_comm
is
None
:
for
s
in
self
.
_local_samples
:
yield
s
if
self
.
_mirror_samples
:
yield
-
s
else
:
res
=
self
.
_samples
if
self
.
_mirror_samples
:
res
=
res
+
tuple
(
-
item
for
item
in
res
)
return
res
ntask
=
self
.
_comm
.
Get_size
()
rank
=
self
.
_comm
.
Get_rank
()
rank_lo_hi
=
[
_shareRange
(
self
.
_n_samples
,
ntask
,
i
)
for
i
in
range
(
ntask
)]
for
itask
,
(
l
,
h
)
in
enumerate
(
rank_lo_hi
):
for
i
in
range
(
l
,
h
):
data
=
self
.
_local_samples
[
i
-
self
.
_lo
]
if
rank
==
itask
else
None
s
=
self
.
_comm
.
bcast
(
data
,
root
=
itask
)
yield
s
if
self
.
_mirror_samples
:
yield
-
s
def
_metric_sample
(
self
,
from_inverse
=
False
,
dtype
=
np
.
float64
):
if
from_inverse
:
...
...
@@ -263,7 +271,7 @@ class MetricGaussianKL(Energy):
lin
=
self
.
_lin
.
with_want_metric
()
samp
=
full
(
self
.
_hamiltonian
.
domain
,
0.
)
sseq
=
random
.
spawn_sseq
(
self
.
_n_samples
)
for
i
,
v
in
enumerate
(
self
.
_samples
):
for
i
,
v
in
enumerate
(
self
.
_
local_
samples
):
random
.
push_sseq
(
sseq
[
self
.
_lo
+
i
])
samp
=
samp
+
self
.
_hamiltonian
(
lin
+
v
).
metric
.
draw_sample
(
from_inverse
=
False
,
dtype
=
dtype
)
if
self
.
_mirror_samples
:
...
...
test/test_kl.py
View file @
193a276f
...
...
@@ -45,13 +45,13 @@ def test_kl(constants, point_estimates, mirror_samples):
point_estimates
=
point_estimates
,
mirror_samples
=
mirror_samples
,
napprox
=
0
)
samp
_full
=
kl
.
samples
loc
samp
=
kl
.
_local_
samples
klpure
=
ift
.
MetricGaussianKL
(
mean0
,
h
,
len
(
samp_full
)
,
mirror_samples
=
False
,
nsamps
,
mirror_samples
=
mirror_samples
,
napprox
=
0
,
_samples
=
samp
_full
)
_local
_samples
=
loc
samp
)
# Test value
assert_allclose
(
kl
.
value
,
klpure
.
value
)
...
...
@@ -66,7 +66,7 @@ def test_kl(constants, point_estimates, mirror_samples):
# Test number of samples
expected_nsamps
=
2
*
nsamps
if
mirror_samples
else
nsamps
assert_
(
len
(
kl
.
samples
)
==
expected_nsamps
)
assert_
(
len
(
tuple
(
kl
.
samples
)
)
==
expected_nsamps
)
# Test point_estimates (after drawing samples)
for
kk
in
point_estimates
:
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment