NIFTy commit bd91033f
Authored Mar 25, 2020 by Martin Reinecke

turn the samples property into a generator

Parent: e66e8b5e
Pipeline #71449: canceled with stages in 4 minutes and 28 seconds
Changes: 2 files
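For orientation, a toy sketch of the user-visible effect of this commit (illustrative Python, not NIFTy code): `MetricGaussianKL.samples` used to return a tuple of residual samples with the mirrored samples appended; it now returns a fresh generator on every access, so callers iterate it directly or materialize it with `tuple(...)` when they need a length or repeated passes.

# Toy illustration of the change, not NIFTy code: `samples` used to hand back a
# tuple, it now hands back a fresh generator on every property access.
class Before:
    _local_samples = (1.0, 2.0)
    _mirror_samples = True

    @property
    def samples(self):
        res = self._local_samples
        if self._mirror_samples:
            res = res + tuple(-s for s in res)
        return res                    # tuple: len(), indexing, re-iteration all work

class After(Before):
    @property
    def samples(self):
        for s in self._local_samples:
            yield s
            if self._mirror_samples:
                yield -s              # generator: iterate it, or wrap it in tuple()

assert len(Before().samples) == 4
assert len(tuple(After().samples)) == 4   # len(After().samples) would raise TypeError

The diff below shows the same switch in context, together with the bookkeeping (`_getTask`, per-task `_local_samples`, broadcasts) needed so that every MPI task still sees the full set of samples.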
nifty6/minimization/metric_gaussian_kl.py
@@ -36,6 +36,18 @@ def _shareRange(nwork, nshares, myshare):
     return lo, hi
 
 
+def _getTask(iwork, nwork, nshares):
+    nbase = nwork // nshares
+    additional = nwork % nshares
+    # FIXME: this is crappy code!
+    for ishare in range(nshares):
+        lo = ishare*nbase + min(ishare, additional)
+        hi = lo + nbase + int(ishare < additional)
+        if hi > iwork:
+            return ishare
+    raise RunTimeError("must not arrive here")
+
+
 def _np_allreduce_sum(comm, arr):
     if comm is None:
         return arr
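A quick way to see what `_getTask` is for: `_shareRange` (whose body is not shown in the hunk) assigns each MPI task a contiguous slice `[lo, hi)` of the work items, and `_getTask` answers the inverse question, i.e. which task owns a given item. The standalone sketch below is not part of the commit; it re-implements both helpers under the assumption that `_shareRange` uses the same splitting rule that `_getTask` iterates over, and checks that the two are consistent.

# Standalone sketch (hypothetical test, not from the commit). Assumes _shareRange
# splits nwork items into nshares contiguous blocks, giving the first
# (nwork % nshares) blocks one extra item, which is the rule _getTask loops over.

def _shareRange(nwork, nshares, myshare):
    nbase = nwork // nshares
    additional = nwork % nshares
    lo = myshare*nbase + min(myshare, additional)
    hi = lo + nbase + int(myshare < additional)
    return lo, hi

def _getTask(iwork, nwork, nshares):
    # same splitting logic as the helper added above
    nbase = nwork // nshares
    additional = nwork % nshares
    for ishare in range(nshares):
        lo = ishare*nbase + min(ishare, additional)
        hi = lo + nbase + int(ishare < additional)
        if hi > iwork:
            return ishare
    raise RuntimeError("must not arrive here")

if __name__ == "__main__":
    nwork, nshares = 7, 3          # e.g. 7 samples distributed over 3 MPI tasks
    # the ranges come out as [0, 3), [3, 5), [5, 7)
    for ishare in range(nshares):
        lo, hi = _shareRange(nwork, nshares, ishare)
        assert all(_getTask(i, nwork, nshares) == ishare for i in range(lo, hi))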
@@ -121,7 +133,7 @@ class MetricGaussianKL(Energy):
         the presence of this parameter is that metric of the likelihood energy
         is just an `Operator` which does not know anything about the dtype of
         the fields on which it acts. Default is float64.
-    _samples : None
+    _local_samples : None
         Only a parameter for internal uses. Typically not to be set by users.
 
     Note
@@ -138,7 +150,7 @@ class MetricGaussianKL(Energy):
     def __init__(self, mean, hamiltonian, n_samples, constants=[],
                  point_estimates=[], mirror_samples=False,
-                 napprox=0, comm=None, _samples=None,
+                 napprox=0, comm=None, _local_samples=None,
                  lh_sampling_dtype=np.float64):
         super(MetricGaussianKL, self).__init__(mean)
@@ -170,31 +182,31 @@ class MetricGaussianKL(Energy):
         if self._mirror_samples:
             self._n_eff_samples *= 2
 
-        if _samples is None:
+        if _local_samples is None:
             met = hamiltonian(Linearization.make_partial_var(
                 mean, self._point_estimates, True)).metric
             if napprox >= 1:
                 met._approximation = makeOp(approximation2endo(met, napprox))
-            _samples = []
+            _local_samples = []
             sseq = random.spawn_sseq(self._n_samples)
             for i in range(self._lo, self._hi):
                 random.push_sseq(sseq[i])
-                _samples.append(met.draw_sample(from_inverse=True,
-                                                dtype=lh_sampling_dtype))
+                _local_samples.append(met.draw_sample(from_inverse=True,
+                                                      dtype=lh_sampling_dtype))
                 random.pop_sseq()
-            _samples = tuple(_samples)
+            _local_samples = tuple(_local_samples)
         else:
-            if len(_samples) != self._hi - self._lo:
+            if len(_local_samples) != self._hi - self._lo:
                 raise ValueError("# of samples mismatch")
-        self._samples = _samples
+        self._local_samples = _local_samples
         self._lin = Linearization.make_partial_var(mean, self._constants)
         v, g = None, None
-        if len(self._samples) == 0:  # hack if there are too many MPI tasks
+        if len(self._local_samples) == 0:  # hack if there are too many MPI tasks
             tmp = self._hamiltonian(self._lin)
             v = 0.*tmp.val.val
             g = 0.*tmp.gradient
         else:
-            for s in self._samples:
+            for s in self._local_samples:
                 tmp = self._hamiltonian(self._lin+s)
                 if self._mirror_samples:
                     tmp = tmp + self._hamiltonian(self._lin-s)
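One detail worth noting in the block above: every task draws only its own slice `range(self._lo, self._hi)` of samples, but seeds each draw with `sseq[i]`, where `i` is the global sample index. A rough numpy analogy of why that matters (illustrative only, NIFTy's `random` module is not involved): seeding by global index makes the drawn samples independent of how the work is split over tasks.

# Illustrative numpy analogy, not NIFTy code: seed each sample by its global
# index and the result no longer depends on which task draws it.
import numpy as np

n_samples = 6
sseq = np.random.SeedSequence(12345).spawn(n_samples)

def draw_local(lo, hi):
    # what a task owning the index range [lo, hi) would produce
    return [np.random.default_rng(sseq[i]).standard_normal(3) for i in range(lo, hi)]

serial = draw_local(0, 6)                       # one task owns everything
parallel = draw_local(0, 3) + draw_local(3, 6)  # two tasks split the work
assert all(np.array_equal(a, b) for a, b in zip(serial, parallel))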
@@ -213,7 +225,7 @@ class MetricGaussianKL(Energy):
         return MetricGaussianKL(
             position, self._hamiltonian, self._n_samples, self._constants,
             self._point_estimates, self._mirror_samples, comm=self._comm,
-            _samples=self._samples, lh_sampling_dtype=self._sampdt)
+            _local_samples=self._local_samples, lh_sampling_dtype=self._sampdt)
 
     @property
     def value(self):
@@ -226,15 +238,15 @@ class MetricGaussianKL(Energy):
     def _get_metric(self):
         lin = self._lin.with_want_metric()
         if self._metric is None:
-            if len(self._samples) == 0:  # hack if there are too many MPI tasks
+            if len(self._local_samples) == 0:  # hack if there are too many MPI tasks
                 self._metric = self._hamiltonian(lin).metric.scale(0.)
             else:
                 mymap = map(lambda v: self._hamiltonian(lin+v).metric,
-                            self._samples)
+                            self._local_samples)
                 unscaled_metric = utilities.my_sum(mymap)
                 if self._mirror_samples:
                     mymap = map(lambda v: self._hamiltonian(lin-v).metric,
-                                self._samples)
+                                self._local_samples)
                     unscaled_metric = unscaled_metric + utilities.my_sum(mymap)
                 self._metric = unscaled_metric.scale(1./self._n_eff_samples)
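Apart from the rename, the structure of `_get_metric` is unchanged: each task sums the metrics evaluated at its local samples (and at their mirrored counterparts when enabled) and scales by `1/self._n_eff_samples`. A scalar stand-in of that arithmetic, assuming `utilities.my_sum` simply adds up the terms of its iterable:

# Scalar stand-in, not NIFTy code; the real terms are metric operators, not floats.
from functools import reduce
import operator

def my_sum(iterable):                  # assumed behaviour of utilities.my_sum
    return reduce(operator.add, iterable)

local_metrics = [2.0, 4.0]             # metrics at this task's samples
mirrored_metrics = [2.5, 3.5]          # metrics at the mirrored (negated) samples
n_eff_samples = 4                      # doubled here because mirroring is enabled

unscaled = my_sum(map(lambda m: m, local_metrics))
unscaled = unscaled + my_sum(mirrored_metrics)
scaled = unscaled*(1./n_eff_samples)   # analogue of unscaled_metric.scale(...)
assert scaled == 3.0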
@@ -248,14 +260,21 @@ class MetricGaussianKL(Energy):
     @property
     def samples(self):
-        if self._comm is not None:
-            res = self._comm.allgather(self._samples)
-            res = [item for sublist in res for item in sublist]
+        if self._comm is None:
+            for s in self._local_samples:
+                yield s
+                if self._mirror_samples:
+                    yield -s
         else:
-            res = self._samples
-        if self._mirror_samples:
-            res = res + tuple(-item for item in res)
-        return res
+            ntask = self._comm.Get_size()
+            rank = self._comm.Get_rank()
+            for i in range(self._n_samples):
+                itask = _getTask(i, self._n_samples, ntask)
+                data = self._local_samples[i-self._lo] if rank == itask else None
+                s = self._comm.bcast(data, root=itask)
+                yield s
+                if self._mirror_samples:
+                    yield -s
 
     def _metric_sample(self, from_inverse=False, dtype=np.float64):
         if from_inverse:
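The rewritten property is a generator in both branches. Without a communicator it simply walks the local samples; with one, every task loops over the global sample indices, the owning task (found via `_getTask`) broadcasts its sample, and all tasks yield the same, globally ordered sequence. Since `comm.bcast` is a collective call, every rank is expected to consume the generator in step. A dry run of that ordering with two simulated ranks and no MPI at all (hypothetical plain Python):

# Hypothetical dry run, no mpi4py involved: two simulated ranks hold disjoint
# local samples, yet both reconstruct the identical global sequence because each
# index is "broadcast" from the rank that owns it.
local = {0: [10.0, 20.0, 30.0],        # rank 0 owns global indices 0, 1, 2
         1: [40.0, 50.0]}              # rank 1 owns global indices 3, 4
lo = {0: 0, 1: 3}                      # per-rank value of self._lo
owner = [0, 0, 0, 1, 1]                # what _getTask would return per index
mirror_samples = True

def samples_seen_by(rank):                # 'rank' only determines who would send;
    for i, itask in enumerate(owner):     # after the broadcast every rank holds
        s = local[itask][i - lo[itask]]   # the owner's value, mimicked here directly
        yield s
        if mirror_samples:
            yield -s

assert list(samples_seen_by(0)) == list(samples_seen_by(1))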
@@ -263,7 +282,7 @@ class MetricGaussianKL(Energy):
             lin = self._lin.with_want_metric()
             samp = full(self._hamiltonian.domain, 0.)
             sseq = random.spawn_sseq(self._n_samples)
-            for i, v in enumerate(self._samples):
+            for i, v in enumerate(self._local_samples):
                 random.push_sseq(sseq[self._lo+i])
                 samp = samp + self._hamiltonian(lin+v).metric.draw_sample(from_inverse=False, dtype=dtype)
                 if self._mirror_samples:
...
...
test/test_kl.py
@@ -45,13 +45,13 @@ def test_kl(constants, point_estimates, mirror_samples):
         point_estimates=point_estimates,
         mirror_samples=mirror_samples,
         napprox=0)
-    samp_full = kl.samples
+    locsamp = kl._local_samples
     klpure = ift.MetricGaussianKL(
         mean0,
         h,
-        len(samp_full),
-        mirror_samples=False,
+        nsamps,
+        mirror_samples=mirror_samples,
         napprox=0,
-        _samples=samp_full)
+        _local_samples=locsamp)
 
     # Test value
     assert_allclose(kl.value, klpure.value)
@@ -66,7 +66,7 @@ def test_kl(constants, point_estimates, mirror_samples):
     # Test number of samples
     expected_nsamps = 2*nsamps if mirror_samples else nsamps
-    assert_(len(kl.samples) == expected_nsamps)
+    assert_(len(tuple(kl.samples)) == expected_nsamps)
 
     # Test point_estimates (after drawing samples)
     for kk in point_estimates: