Commit 45820d90 authored by Theo Steininger's avatar Theo Steininger

Added int8 as possible dtype. Fixed tests with respect to this issue:...

Added int8 as possible dtype. Fixed tests with respect to this issue: https://github.com/numpy/numpy/issues/9290
parent 22e1d389
Pipeline #14847 failed with stage
in 3 minutes and 18 seconds
......@@ -750,8 +750,6 @@ class distributed_data_object(Loggable, Versionable, object):
# use common datatype for self and other
new_dtype = np.dtype(np.find_common_type((self.dtype,),
(result_data.dtype,)))
if new_dtype == np.int8:
new_dtype == np.int16
temp_d2o = self.copy_empty(dtype=new_dtype)
# write the new data into the return-distributed_data_object
......
......@@ -38,6 +38,7 @@ class _dtype_converter(object):
# [, MPI_SIGNED_CHAR],
# [, MPI_UNSIGNED_CHAR],
[np.dtype('bool'), MPI.BYTE],
[np.dtype('int8'), MPI.BYTE],
[np.dtype('int16'), MPI.SHORT],
[np.dtype('uint16'), MPI.UNSIGNED_SHORT],
[np.dtype('uint32'), MPI.UNSIGNED_INT],
......
......@@ -76,7 +76,7 @@ np.random.seed(123)
# np.int, np.int64, np.uint64, np.float32, np.float_, np.float,
# np.float64, np.float128, np.complex64, np.complex_,
# np.complex, np.complex128]
all_datatypes = [np.dtype('bool'), np.dtype('int16'), np.dtype('uint16'),
all_datatypes = [np.dtype('bool'), np.dtype('int8'), np.dtype('int16'), np.dtype('uint16'),
np.dtype('uint32'), np.dtype('int32'), np.dtype('int'),
np.dtype('int64'), np.dtype('uint'), np.dtype('uint64'),
np.dtype('float32'), np.dtype(
......@@ -1557,6 +1557,9 @@ class Test_contractions(unittest.TestCase):
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy,
strictly_positive=True)
if dtype is np.dtype('int8'):
a[a > 8] = 0
obj[obj > 8] = 0
assert_allclose(getattr(obj, function)(), getattr(np, function)(a),
rtol=1e-4)
......@@ -1833,6 +1836,9 @@ class Test_axis(unittest.TestCase):
if axis in [(1, ), (0, 1)] and global_shape == (0,):
assert_raises(Exception, lambda: getattr(obj, function)
(axis=axis))
elif function in ['all', 'any']:
assert_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis))
else:
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
......@@ -1856,18 +1862,13 @@ class Test_axis(unittest.TestCase):
if function in ['argmin', 'argmax'] and axis is not None:
assert_raises(NotImplementedError, lambda: getattr(obj, function)
(axis=axis))
elif (global_shape == (2, 3) or
(global_shape == (1,) and axis in [None, 0, (0,)]) or
(global_shape == () and axis is None)):
if function in ['all', 'any']:
assert_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis))
else:
if global_shape == (2, 3):
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
elif global_shape == (1,):
if axis in [None, 0, (0,)]:
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
else:
if axis in [None]:
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
......@@ -1896,6 +1897,9 @@ class Test_axis(unittest.TestCase):
and 0 in global_shape:
assert_raises(ValueError, lambda: getattr(obj, function)
(axis=axis))
elif function in ['all', 'any']:
assert_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis))
else:
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment