Commit f39083b1 authored by csongor's avatar csongor
Browse files

fix test dimentions and some max and min cases

parent 08e80a04
Pipeline #1963 skipped
......@@ -517,14 +517,25 @@ class _slicing_distributor(distributor):
# check if additional contraction along the first axis must be done
if axis is None or 0 in axis:
(mpi_op, bufferQ) = op_translate_dict[function]
contracted_local_data = self.comm.allreduce(contracted_local_data,
op=mpi_op)
if bufferQ and isinstance(contracted_local_data, np.ndarray):
global_contracted_local_data = np.empty_like(
contracted_local_data)
new_mpi_dtype = self._my_dtype_converter.to_mpi(new_dtype)
self.comm.Allreduce([contracted_local_data,
new_mpi_dtype],
[global_contracted_local_data,
new_mpi_dtype],
op=mpi_op)
else:
global_contracted_local_data = self.comm.allreduce(
contracted_local_data, op=mpi_op)
new_dist_strategy = 'not'
else:
new_dist_strategy = parent.distribution_strategy
global_contracted_local_data = contracted_local_data
if new_shape == ():
result = contracted_local_data
result = global_contracted_local_data
else:
# try to store the result in a distributed_data_object with the
# distribution_strategy as parent
......@@ -538,12 +549,12 @@ class _slicing_distributor(distributor):
# Contracting (4, 4) to (4,).
# (4, 4) was distributed (1, 4)...(1, 4)
# (4, ) is not distributed like (1,)...(1,) but like (2,)(2,)()()!
if result.local_shape != contracted_local_data.shape:
if result.local_shape != global_contracted_local_data.shape:
result = parent.copy_empty(
local_shape=contracted_local_data.shape,
local_shape=global_contracted_local_data.shape,
dtype=new_dtype,
distribution_strategy='freeform')
result.set_local_data(contracted_local_data, copy=False)
result.set_local_data(global_contracted_local_data, copy=False)
return result
......
......@@ -1747,7 +1747,7 @@ class Test_axis(unittest.TestCase):
'argmin_nonflat', 'max', 'amax', 'nanmax',
'argmax', 'argmax_nonflat'],
all_datatypes[1:],
[(0,), (1,), (6, 6), (5, 5, 5)],
[(0,), (1,), (6, 6), (4, 4, 3)],
all_distribution_strategies,
[None, 0, (1, ), (0, 1)]),
testcase_func_name=custom_name_func)
......@@ -1779,7 +1779,7 @@ class Test_axis(unittest.TestCase):
'argmin_nonflat', 'max', 'amax', 'nanmax',
'argmax', 'argmax_nonflat'],
all_datatypes[1:],
[(5, 5, 5), (4, 0, 3)],
[(4, 4, 3), (4, 0, 3)],
all_distribution_strategies, [(0, 1), (1, 2)]),
testcase_func_name=custom_name_func)
def test_axis_with_functions_for_many_dimentions(self, function, dtype,
......
......@@ -31,19 +31,19 @@ from nifty.operators.nifty_operators import power_operator
available = []
try:
from nifty import lm_space
from nifty import lm_space
except ImportError:
pass
else:
available += ['lm_space']
try:
from nifty import gl_space
from nifty import gl_space
except ImportError:
pass
else:
available += ['gl_space']
try:
from nifty import hp_space
from nifty import hp_space
except ImportError:
pass
else:
......@@ -1364,7 +1364,7 @@ class Test_axis(unittest.TestCase):
[None, (0,)],
DATAMODELS['point_space']),
testcase_func_name=custom_name_func)
def test_binary_operations(self, name, num, op, axis, datamodel):
def test_unary_operations(self, name, num, op, axis, datamodel):
s = generate_space_with_size(name, np.prod(num), datamodel=datamodel)
d = generate_data(s)
a = d.get_full_data()
......@@ -1375,4 +1375,6 @@ class Test_axis(unittest.TestCase):
getattr(np, op)(a, axis=axis), decimal=4)
if name in ['rg_space']:
assert_almost_equal(s.unary_operation(d, op, axis=(0, 1)),
getattr(np, op)(a, axis=(0, 1)), decimal=4)
\ No newline at end of file
getattr(np, op)(a, axis=(0, 1)), decimal=4)
assert_almost_equal(s.unary_operation(d, op, axis=(1,)),
getattr(np, op)(a, axis=(1,)), decimal=4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment