Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
f39083b1
Commit
f39083b1
authored
Apr 25, 2016
by
csongor
Browse files
fix test dimentions and some max and min cases
parent
08e80a04
Changes
3
Hide whitespace changes
Inline
Side-by-side
d2o/distributor_factory.py
View file @
f39083b1
...
...
@@ -517,14 +517,25 @@ class _slicing_distributor(distributor):
# check if additional contraction along the first axis must be done
if
axis
is
None
or
0
in
axis
:
(
mpi_op
,
bufferQ
)
=
op_translate_dict
[
function
]
contracted_local_data
=
self
.
comm
.
allreduce
(
contracted_local_data
,
op
=
mpi_op
)
if
bufferQ
and
isinstance
(
contracted_local_data
,
np
.
ndarray
):
global_contracted_local_data
=
np
.
empty_like
(
contracted_local_data
)
new_mpi_dtype
=
self
.
_my_dtype_converter
.
to_mpi
(
new_dtype
)
self
.
comm
.
Allreduce
([
contracted_local_data
,
new_mpi_dtype
],
[
global_contracted_local_data
,
new_mpi_dtype
],
op
=
mpi_op
)
else
:
global_contracted_local_data
=
self
.
comm
.
allreduce
(
contracted_local_data
,
op
=
mpi_op
)
new_dist_strategy
=
'not'
else
:
new_dist_strategy
=
parent
.
distribution_strategy
global_contracted_local_data
=
contracted_local_data
if
new_shape
==
():
result
=
contracted_local_data
result
=
global_
contracted_local_data
else
:
# try to store the result in a distributed_data_object with the
# distribution_strategy as parent
...
...
@@ -538,12 +549,12 @@ class _slicing_distributor(distributor):
# Contracting (4, 4) to (4,).
# (4, 4) was distributed (1, 4)...(1, 4)
# (4, ) is not distributed like (1,)...(1,) but like (2,)(2,)()()!
if
result
.
local_shape
!=
contracted_local_data
.
shape
:
if
result
.
local_shape
!=
global_
contracted_local_data
.
shape
:
result
=
parent
.
copy_empty
(
local_shape
=
contracted_local_data
.
shape
,
local_shape
=
global_
contracted_local_data
.
shape
,
dtype
=
new_dtype
,
distribution_strategy
=
'freeform'
)
result
.
set_local_data
(
contracted_local_data
,
copy
=
False
)
result
.
set_local_data
(
global_
contracted_local_data
,
copy
=
False
)
return
result
...
...
test/test_nifty_mpi_data.py
View file @
f39083b1
...
...
@@ -1747,7 +1747,7 @@ class Test_axis(unittest.TestCase):
'argmin_nonflat'
,
'max'
,
'amax'
,
'nanmax'
,
'argmax'
,
'argmax_nonflat'
],
all_datatypes
[
1
:],
[(
0
,),
(
1
,),
(
6
,
6
),
(
5
,
5
,
5
)],
[(
0
,),
(
1
,),
(
6
,
6
),
(
4
,
4
,
3
)],
all_distribution_strategies
,
[
None
,
0
,
(
1
,
),
(
0
,
1
)]),
testcase_func_name
=
custom_name_func
)
...
...
@@ -1779,7 +1779,7 @@ class Test_axis(unittest.TestCase):
'argmin_nonflat'
,
'max'
,
'amax'
,
'nanmax'
,
'argmax'
,
'argmax_nonflat'
],
all_datatypes
[
1
:],
[(
5
,
5
,
5
),
(
4
,
0
,
3
)],
[(
4
,
4
,
3
),
(
4
,
0
,
3
)],
all_distribution_strategies
,
[(
0
,
1
),
(
1
,
2
)]),
testcase_func_name
=
custom_name_func
)
def
test_axis_with_functions_for_many_dimentions
(
self
,
function
,
dtype
,
...
...
test/test_nifty_spaces.py
View file @
f39083b1
...
...
@@ -31,19 +31,19 @@ from nifty.operators.nifty_operators import power_operator
available
=
[]
try
:
from
nifty
import
lm_space
from
nifty
import
lm_space
except
ImportError
:
pass
else
:
available
+=
[
'lm_space'
]
try
:
from
nifty
import
gl_space
from
nifty
import
gl_space
except
ImportError
:
pass
else
:
available
+=
[
'gl_space'
]
try
:
from
nifty
import
hp_space
from
nifty
import
hp_space
except
ImportError
:
pass
else
:
...
...
@@ -1364,7 +1364,7 @@ class Test_axis(unittest.TestCase):
[
None
,
(
0
,)],
DATAMODELS
[
'point_space'
]),
testcase_func_name
=
custom_name_func
)
def
test_
bi
nary_operations
(
self
,
name
,
num
,
op
,
axis
,
datamodel
):
def
test_
u
nary_operations
(
self
,
name
,
num
,
op
,
axis
,
datamodel
):
s
=
generate_space_with_size
(
name
,
np
.
prod
(
num
),
datamodel
=
datamodel
)
d
=
generate_data
(
s
)
a
=
d
.
get_full_data
()
...
...
@@ -1375,4 +1375,6 @@ class Test_axis(unittest.TestCase):
getattr
(
np
,
op
)(
a
,
axis
=
axis
),
decimal
=
4
)
if
name
in
[
'rg_space'
]:
assert_almost_equal
(
s
.
unary_operation
(
d
,
op
,
axis
=
(
0
,
1
)),
getattr
(
np
,
op
)(
a
,
axis
=
(
0
,
1
)),
decimal
=
4
)
\ No newline at end of file
getattr
(
np
,
op
)(
a
,
axis
=
(
0
,
1
)),
decimal
=
4
)
assert_almost_equal
(
s
.
unary_operation
(
d
,
op
,
axis
=
(
1
,)),
getattr
(
np
,
op
)(
a
,
axis
=
(
1
,)),
decimal
=
4
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment