Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
74a9b76c
Commit
74a9b76c
authored
Apr 29, 2016
by
theos
Browse files
Fixed the exception handling in _selective_allreduce.
parent
b511fde7
Pipeline
#2152
skipped
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
d2o/distributor_factory.py
View file @
74a9b76c
...
@@ -503,6 +503,8 @@ class _slicing_distributor(distributor):
...
@@ -503,6 +503,8 @@ class _slicing_distributor(distributor):
rank
=
self
.
comm
.
rank
rank
=
self
.
comm
.
rank
if
size
==
1
:
if
size
==
1
:
if
data
is
None
:
raise
ValueError
(
"ERROR: No process with non-None data."
)
result_data
=
data
result_data
=
data
else
:
else
:
...
@@ -511,26 +513,31 @@ class _slicing_distributor(distributor):
...
@@ -511,26 +513,31 @@ class _slicing_distributor(distributor):
if
data
is
None
:
if
data
is
None
:
got_array
=
np
.
array
([
0
])
got_array
=
np
.
array
([
0
])
elif
not
isinstance
(
data
,
np
.
ndarray
):
elif
not
isinstance
(
data
,
np
.
ndarray
):
got_array
=
np
.
array
([
2
])
elif
reduce
(
lambda
x
,
y
:
x
*
y
,
data
.
shape
)
==
0
:
got_array
=
np
.
array
([
1
])
got_array
=
np
.
array
([
1
])
elif
np
.
issubdtype
(
data
.
dtype
,
np
.
complexfloating
):
elif
np
.
issubdtype
(
data
.
dtype
,
np
.
complexfloating
):
# MPI.MAX and MPI.MIN do not support complex data types
# MPI.MAX and MPI.MIN do not support complex data types
got_array
=
np
.
array
([
2
])
else
:
got_array
=
np
.
array
([
3
])
got_array
=
np
.
array
([
3
])
else
:
got_array
=
np
.
array
([
4
])
got_array_list
=
np
.
empty
(
size
,
dtype
=
np
.
int
)
got_array_list
=
np
.
empty
(
size
,
dtype
=
np
.
int
)
self
.
comm
.
Allgather
([
got_array
,
MPI
.
INT
],
self
.
comm
.
Allgather
([
got_array
,
MPI
.
INT
],
[
got_array_list
,
MPI
.
INT
])
[
got_array_list
,
MPI
.
INT
])
if
reduce
(
lambda
x
,
y
:
x
&
y
,
got_array_list
==
1
):
return
data
# get first node with non-None data
# get first node with non-None data
try
:
try
:
start
=
next
(
i
for
i
in
xrange
(
size
)
if
got_array_list
[
i
]
>
0
)
start
=
next
(
i
for
i
in
xrange
(
size
)
if
got_array_list
[
i
]
>
1
)
except
(
StopIteration
):
except
(
StopIteration
):
raise
ValueError
(
"ERROR: No process with non-None data."
)
raise
ValueError
(
"ERROR: No process with non-None data."
)
# check if the Uppercase function can be used or not
# check if the Uppercase function can be used or not
# -> check if op supports buffers and if we got real array-data
# -> check if op supports buffers and if we got real array-data
if
bufferQ
and
got_array
[
start
]
==
3
:
if
bufferQ
and
got_array
[
start
]
==
4
:
# Send the dtype and shape from the start process to the others
# Send the dtype and shape from the start process to the others
(
new_dtype
,
(
new_dtype
,
new_shape
)
=
self
.
comm
.
bcast
((
data
.
dtype
,
new_shape
)
=
self
.
comm
.
bcast
((
data
.
dtype
,
...
@@ -544,7 +551,7 @@ class _slicing_distributor(distributor):
...
@@ -544,7 +551,7 @@ class _slicing_distributor(distributor):
self
.
comm
.
Bcast
([
result_data
,
mpi_dtype
],
root
=
start
)
self
.
comm
.
Bcast
([
result_data
,
mpi_dtype
],
root
=
start
)
for
i
in
xrange
(
start
+
1
,
size
):
for
i
in
xrange
(
start
+
1
,
size
):
if
got_array_list
[
i
]:
if
got_array_list
[
i
]
>
1
:
if
rank
==
i
:
if
rank
==
i
:
temp_data
=
data
temp_data
=
data
else
:
else
:
...
@@ -555,7 +562,7 @@ class _slicing_distributor(distributor):
...
@@ -555,7 +562,7 @@ class _slicing_distributor(distributor):
else
:
else
:
result_data
=
self
.
comm
.
bcast
(
data
,
root
=
start
)
result_data
=
self
.
comm
.
bcast
(
data
,
root
=
start
)
for
i
in
xrange
(
start
+
1
,
size
):
for
i
in
xrange
(
start
+
1
,
size
):
if
got_array_list
[
i
]:
if
got_array_list
[
i
]
>
1
:
temp_data
=
self
.
comm
.
bcast
(
data
,
root
=
i
)
temp_data
=
self
.
comm
.
bcast
(
data
,
root
=
i
)
result_data
=
op
(
result_data
,
temp_data
)
result_data
=
op
(
result_data
,
temp_data
)
return
result_data
return
result_data
...
@@ -575,21 +582,10 @@ class _slicing_distributor(distributor):
...
@@ -575,21 +582,10 @@ class _slicing_distributor(distributor):
local_data
=
parent
.
data
local_data
=
parent
.
data
# if all local data is empty and empty_contractions are forbidden
try
:
# call function on the local_data in order to raise the right exception
if
self
.
global_dim
==
0
and
not
allow_empty_contractions
:
# this shall raise an exception
function
(
local_data
,
axis
=
axis
,
**
kwargs
)
# do the contraction on the node's local data
if
self
.
local_dim
==
0
and
not
allow_empty_contractions
:
# this case will only be reached if some nodes have data and some
# not
contracted_local_data
=
None
else
:
# if local_dim == 0 but empty contractions will be allowed
# this will be a `contraction neutral` array.
contracted_local_data
=
function
(
local_data
,
axis
=
axis
,
**
kwargs
)
contracted_local_data
=
function
(
local_data
,
axis
=
axis
,
**
kwargs
)
except
(
ValueError
):
contracted_local_data
=
None
# check if additional contraction along the first axis must be done
# check if additional contraction along the first axis must be done
if
axis
is
None
or
0
in
axis
:
if
axis
is
None
or
0
in
axis
:
...
@@ -600,6 +596,9 @@ class _slicing_distributor(distributor):
...
@@ -600,6 +596,9 @@ class _slicing_distributor(distributor):
bufferQ
)
bufferQ
)
new_dist_strategy
=
'not'
new_dist_strategy
=
'not'
else
:
else
:
if
contracted_local_data
is
None
:
# raise the exception implicitly
function
(
local_data
,
axis
=
axis
,
**
kwargs
)
contracted_global_data
=
contracted_local_data
contracted_global_data
=
contracted_local_data
new_dist_strategy
=
parent
.
distribution_strategy
new_dist_strategy
=
parent
.
distribution_strategy
...
...
test/test_nifty_mpi_data.py
View file @
74a9b76c
...
@@ -1752,7 +1752,7 @@ class Test_axis(unittest.TestCase):
...
@@ -1752,7 +1752,7 @@ class Test_axis(unittest.TestCase):
all_distribution_strategies
,
all_distribution_strategies
,
[
None
,
(
0
,
),
(
1
,
),
(
0
,
1
)]),
[
None
,
(
0
,
),
(
1
,
),
(
0
,
1
)]),
testcase_func_name
=
custom_name_func
)
testcase_func_name
=
custom_name_func
)
def
test_axis_with_operations_0_dimen
t
ion
(
self
,
function
,
dtype
,
def
test_axis_with_operations_0_dimen
s
ion
(
self
,
function
,
dtype
,
global_shape
,
global_shape
,
distribution_strategy
,
axis
):
distribution_strategy
,
axis
):
(
a
,
obj
)
=
generate_data
(
global_shape
,
dtype
,
(
a
,
obj
)
=
generate_data
(
global_shape
,
dtype
,
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment