Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
2aceefe5
Commit
2aceefe5
authored
May 27, 2020
by
Reimar Leike
Browse files
Made _sumup of KL faster by summign up in parallel when possible
parent
a8c78eb9
Pipeline
#75626
passed with stages
in 13 minutes and 16 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/minimization/metric_gaussian_kl.py
View file @
2aceefe5
...
...
@@ -228,6 +228,7 @@ class MetricGaussianKL(Energy):
if
self
.
_mirror_samples
:
yield
-
s
def
_sumup
(
self
,
obj
):
# This is a deterministic implementation of MPI allreduce in the sense
# that it takes into account that floating point operations are not
...
...
@@ -240,13 +241,44 @@ class MetricGaussianKL(Energy):
ntask
=
self
.
_comm
.
Get_size
()
rank
=
self
.
_comm
.
Get_rank
()
rank_lo_hi
=
[
_shareRange
(
self
.
_n_samples
,
ntask
,
i
)
for
i
in
range
(
ntask
)]
for
itask
,
(
l
,
h
)
in
enumerate
(
rank_lo_hi
):
for
i
in
range
(
l
,
h
):
o
=
obj
[
i
-
self
.
_lo
]
if
rank
==
itask
else
None
o
=
self
.
_comm
.
bcast
(
o
,
root
=
itask
)
res
=
o
if
res
is
None
else
res
+
o
lo
=
rank_lo_hi
[
rank
][
0
]
hi
=
rank_lo_hi
[
rank
][
1
]
vals
=
[]
for
i
in
range
(
self
.
_n_samples
):
if
(
i
>=
lo
)
and
(
i
<
hi
):
vals
+=
[
obj
[
i
-
lo
]]
else
:
vals
+=
[
None
]
who
=
np
.
zeros
(
len
(
vals
),
dtype
=
np
.
int32
)
for
t
,
(
l
,
h
)
in
enumerate
(
rank_lo_hi
):
who
[
l
:
h
]
=
t
def
add2
(
v
,
w
,
rank
):
#Note that communication only happens if rank in w
if
len
(
v
)
==
1
:
return
v
[
0
]
if
rank
==
w
[
0
]:
if
w
[
0
]
==
w
[
1
]:
return
v
[
0
]
+
v
[
1
]
self
.
_comm
.
send
(
v
[
0
],
dest
=
w
[
1
])
return
None
if
rank
==
w
[
1
]:
return
self
.
_comm
.
recv
(
source
=
w
[
0
])
+
v
[
1
]
while
len
(
vals
)
>
1
:
new_vals
=
[]
new_who
=
[]
for
j
in
range
((
len
(
vals
)
+
1
)
//
2
):
w
=
who
[
2
*
j
:
2
*
j
+
2
]
nv
=
add2
(
vals
[
2
*
j
:
2
*
j
+
2
],
w
,
rank
)
new_vals
+=
[
nv
]
new_who
+=
[
w
[
-
1
]]
vals
=
new_vals
who
=
new_who
res
=
self
.
_comm
.
bcast
(
vals
[
0
],
root
=
who
[
0
])
return
res
def
_metric_sample
(
self
,
from_inverse
=
False
):
if
from_inverse
:
raise
NotImplementedError
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment