Commit 45820d90 authored by Theo Steininger's avatar Theo Steininger

Added int8 as possible dtype. Fixed tests with respect to this issue:...

Added int8 as possible dtype. Fixed tests with respect to this issue: https://github.com/numpy/numpy/issues/9290
parent 22e1d389
Pipeline #14847 failed with stage
in 3 minutes and 18 seconds
...@@ -750,8 +750,6 @@ class distributed_data_object(Loggable, Versionable, object): ...@@ -750,8 +750,6 @@ class distributed_data_object(Loggable, Versionable, object):
# use common datatype for self and other # use common datatype for self and other
new_dtype = np.dtype(np.find_common_type((self.dtype,), new_dtype = np.dtype(np.find_common_type((self.dtype,),
(result_data.dtype,))) (result_data.dtype,)))
if new_dtype == np.int8:
new_dtype == np.int16
temp_d2o = self.copy_empty(dtype=new_dtype) temp_d2o = self.copy_empty(dtype=new_dtype)
# write the new data into the return-distributed_data_object # write the new data into the return-distributed_data_object
......
...@@ -38,6 +38,7 @@ class _dtype_converter(object): ...@@ -38,6 +38,7 @@ class _dtype_converter(object):
# [, MPI_SIGNED_CHAR], # [, MPI_SIGNED_CHAR],
# [, MPI_UNSIGNED_CHAR], # [, MPI_UNSIGNED_CHAR],
[np.dtype('bool'), MPI.BYTE], [np.dtype('bool'), MPI.BYTE],
[np.dtype('int8'), MPI.BYTE],
[np.dtype('int16'), MPI.SHORT], [np.dtype('int16'), MPI.SHORT],
[np.dtype('uint16'), MPI.UNSIGNED_SHORT], [np.dtype('uint16'), MPI.UNSIGNED_SHORT],
[np.dtype('uint32'), MPI.UNSIGNED_INT], [np.dtype('uint32'), MPI.UNSIGNED_INT],
......
...@@ -76,7 +76,7 @@ np.random.seed(123) ...@@ -76,7 +76,7 @@ np.random.seed(123)
# np.int, np.int64, np.uint64, np.float32, np.float_, np.float, # np.int, np.int64, np.uint64, np.float32, np.float_, np.float,
# np.float64, np.float128, np.complex64, np.complex_, # np.float64, np.float128, np.complex64, np.complex_,
# np.complex, np.complex128] # np.complex, np.complex128]
all_datatypes = [np.dtype('bool'), np.dtype('int16'), np.dtype('uint16'), all_datatypes = [np.dtype('bool'), np.dtype('int8'), np.dtype('int16'), np.dtype('uint16'),
np.dtype('uint32'), np.dtype('int32'), np.dtype('int'), np.dtype('uint32'), np.dtype('int32'), np.dtype('int'),
np.dtype('int64'), np.dtype('uint'), np.dtype('uint64'), np.dtype('int64'), np.dtype('uint'), np.dtype('uint64'),
np.dtype('float32'), np.dtype( np.dtype('float32'), np.dtype(
...@@ -1557,6 +1557,9 @@ class Test_contractions(unittest.TestCase): ...@@ -1557,6 +1557,9 @@ class Test_contractions(unittest.TestCase):
(a, obj) = generate_data(global_shape, dtype, (a, obj) = generate_data(global_shape, dtype,
distribution_strategy, distribution_strategy,
strictly_positive=True) strictly_positive=True)
if dtype is np.dtype('int8'):
a[a > 8] = 0
obj[obj > 8] = 0
assert_allclose(getattr(obj, function)(), getattr(np, function)(a), assert_allclose(getattr(obj, function)(), getattr(np, function)(a),
rtol=1e-4) rtol=1e-4)
...@@ -1833,6 +1836,9 @@ class Test_axis(unittest.TestCase): ...@@ -1833,6 +1836,9 @@ class Test_axis(unittest.TestCase):
if axis in [(1, ), (0, 1)] and global_shape == (0,): if axis in [(1, ), (0, 1)] and global_shape == (0,):
assert_raises(Exception, lambda: getattr(obj, function) assert_raises(Exception, lambda: getattr(obj, function)
(axis=axis)) (axis=axis))
elif function in ['all', 'any']:
assert_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis))
else: else:
assert_almost_equal(getattr(obj, function)(axis=axis), assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis), getattr(np, function)(a, axis=axis),
...@@ -1856,21 +1862,16 @@ class Test_axis(unittest.TestCase): ...@@ -1856,21 +1862,16 @@ class Test_axis(unittest.TestCase):
if function in ['argmin', 'argmax'] and axis is not None: if function in ['argmin', 'argmax'] and axis is not None:
assert_raises(NotImplementedError, lambda: getattr(obj, function) assert_raises(NotImplementedError, lambda: getattr(obj, function)
(axis=axis)) (axis=axis))
else: elif (global_shape == (2, 3) or
if global_shape == (2, 3): (global_shape == (1,) and axis in [None, 0, (0,)]) or
(global_shape == () and axis is None)):
if function in ['all', 'any']:
assert_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis))
else:
assert_almost_equal(getattr(obj, function)(axis=axis), assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis), getattr(np, function)(a, axis=axis),
decimal=4) decimal=4)
elif global_shape == (1,):
if axis in [None, 0, (0,)]:
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
else:
if axis in [None]:
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
@parameterized.expand( @parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any', itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
...@@ -1896,6 +1897,9 @@ class Test_axis(unittest.TestCase): ...@@ -1896,6 +1897,9 @@ class Test_axis(unittest.TestCase):
and 0 in global_shape: and 0 in global_shape:
assert_raises(ValueError, lambda: getattr(obj, function) assert_raises(ValueError, lambda: getattr(obj, function)
(axis=axis)) (axis=axis))
elif function in ['all', 'any']:
assert_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis))
else: else:
assert_almost_equal(getattr(obj, function)(axis=axis), assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis), getattr(np, function)(a, axis=axis),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment