Commit facb6103 authored by csongor's avatar csongor
Browse files

Fix median, add axis for argmin etc.

parent f399b0da
Pipeline #1815 skipped
......@@ -1182,7 +1182,7 @@ class distributed_data_object(object):
else:
return var.apply_scalar_function(np.sqrt)
def argmin(self):
def argmin(self, axis=None):
""" Returns the (flat) index of the d2o's smallest value.
See Also:
......@@ -1192,6 +1192,9 @@ class distributed_data_object(object):
if 0 in self.shape:
raise ValueError(
"ERROR: attempt to get argmin of an empty object")
if axis is not None:
raise NotImplementedError("ERROR: argmin doesn't support axis "
"keyword")
if 0 in self.local_shape:
local_argmin = np.nan
local_argmin_value = np.nan
......@@ -1213,7 +1216,7 @@ class distributed_data_object(object):
order=['value', 'index'])
return np.int(local_argmin_list[0][1])
def argmax(self):
def argmax(self, axis=None):
""" Returns the (flat) index of the d2o's biggest value.
See Also:
......@@ -1223,6 +1226,9 @@ class distributed_data_object(object):
if 0 in self.shape:
raise ValueError(
"ERROR: attempt to get argmax of an empty object")
if axis is not None:
raise NotImplementedError("ERROR: argmax doesn't support axis "
"keyword")
if 0 in self.local_shape:
local_argmax = np.nan
local_argmax_value = np.nan
......@@ -1243,23 +1249,22 @@ class distributed_data_object(object):
order=['value', 'index'])
return np.int(local_argmax_list[0][1])
def argmin_nonflat(self):
def argmin_nonflat(self, axis=None):
""" Returns the unraveld index of the d2o's smallest value.
See Also:
argmin, argmax, argmax_nonflat
"""
return np.unravel_index(self.argmin(), self.shape)
return np.unravel_index(self.argmin(axis=axis), self.shape)
def argmax_nonflat(self):
def argmax_nonflat(self, axis=None):
""" Returns the unraveld index of the d2o's biggest value.
See Also:
argmin, argmax, argmin_nonflat
"""
return np.unravel_index(self.argmax(), self.shape)
return np.unravel_index(self.argmax(axis=axis), self.shape)
def conjugate(self):
""" Returns the element-wise complex conjugate. """
......@@ -1292,9 +1297,12 @@ class distributed_data_object(object):
if isinstance(median, numbers.Number):
return median
else:
return distributed_data_object(global_data=median,
global_shape=median.shape,
distribution_strategy='not')
x = self.copy_empty(global_shape=median.shape,
dtype=median.dtype,
distribution_strategy='not')
x.set_local_data(median)
return x
def _is_helper(self, function):
""" _is_helper is used for functions like isreal, isinf, isfinite,...
......
......@@ -911,10 +911,10 @@ class point_space(space):
'mean': lambda y: getattr(y, 'mean')(axis=axis),
'std': lambda y: getattr(y, 'std')(axis=axis),
'var': lambda y: getattr(y, 'var')(axis=axis),
'argmin': lambda y: getattr(y, 'argmin_nonflat')(),
'argmin_flat': lambda y: getattr(y, 'argmin')(),
'argmax': lambda y: getattr(y, 'argmax_nonflat')(),
'argmax_flat': lambda y: getattr(y, 'argmax')(),
'argmin_nonflat': lambda y: getattr(y, 'argmin_nonflat')(axis=axis),
'argmin': lambda y: getattr(y, 'argmin')(axis=axis),
'argmax_nonflat': lambda y: getattr(y, 'argmax_nonflat')(axis=axis),
'argmax': lambda y: getattr(y, 'argmax')(axis=axis),
'conjugate': lambda y: getattr(y, 'conjugate')(),
'sum': lambda y: getattr(y, 'sum')(axis=axis),
'prod': lambda y: getattr(y, 'prod')(axis=axis),
......@@ -2747,10 +2747,10 @@ class field(object):
"""
if split:
return self._unary_helper(self.get_val(), op='argmin',
return self._unary_helper(self.get_val(), op='argmin_nonflat',
**kwargs)
else:
return self._unary_helper(self.get_val(), op='argmin_flat',
return self._unary_helper(self.get_val(), op='argmin',
**kwargs)
def argmax(self, split=True, **kwargs):
......@@ -2776,10 +2776,10 @@ class field(object):
"""
if split:
return self._unary_helper(self.get_val(), op='argmax',
return self._unary_helper(self.get_val(), op='argmax_nonflat',
**kwargs)
else:
return self._unary_helper(self.get_val(), op='argmax_flat',
return self._unary_helper(self.get_val(), op='argmax',
**kwargs)
# TODO: Implement the full range of unary and binary operotions
......
......@@ -1743,8 +1743,9 @@ class Test_axis(unittest.TestCase):
@parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'median', 'all',
'any', 'min', 'amin', 'nanmin', 'max',
'amax', 'nanmax'],
'any', 'min', 'amin', 'nanmin', 'argmin',
'argmin_nonflat', 'max', 'amax', 'nanmax',
'argmax', 'argmax_nonflat'],
all_datatypes[1:],
[(0,), (1,), (6, 6), (5, 5, 5)],
all_distribution_strategies,
......@@ -1755,23 +1756,28 @@ class Test_axis(unittest.TestCase):
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy,
strictly_positive=True)
if global_shape != (0,) and global_shape != (1,):
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis), decimal=4)
if function in ['argmin', 'argmin_nonflat', 'argmax', 'argmax_nonflat']:
assert_raises(NotImplementedError)
else:
if function in ['min', 'amin', 'nanmin', 'max',
'amax', 'nanmax']:
assert_raises(ValueError)
if global_shape != (0,) and global_shape != (1,):
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
else:
if axis is None or axis == 0 or axis == (0,):
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
if function in ['min', 'amin', 'nanmin', 'max',
'amax', 'nanmax']:
assert_raises(ValueError)
else:
if axis is None or axis == 0 or axis == (0,):
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
@parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'median', 'all',
'any', 'min', 'amin', 'nanmin', 'max',
'amax', 'nanmax'],
'any', 'min', 'amin', 'nanmin', 'argmin',
'argmin_nonflat', 'max', 'amax', 'nanmax',
'argmax', 'argmax_nonflat'],
all_datatypes[1:],
[(5, 5, 5), (4, 0, 3)],
all_distribution_strategies, [(0, 1), (1, 2)]),
......@@ -1783,10 +1789,13 @@ class Test_axis(unittest.TestCase):
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy,
strictly_positive=True)
if function in ['min', 'amin', 'nanmin', 'max', 'amax', 'nanmax']\
and np.prod(global_shape) == 0:
assert_raises(ValueError)
if function in ['argmin', 'argmin_nonflat', 'argmax', 'argmax_nonflat']:
assert_raises(NotImplementedError)
else:
assert_almost_equal(getattr(obj, function)
(axis=axis).get_full_data(),
getattr(np, function)(a, axis=axis), decimal=4)
if function in ['min', 'amin', 'nanmin', 'max', 'amax', 'nanmax']\
and np.prod(global_shape) == 0:
assert_raises(ValueError)
else:
assert_almost_equal(getattr(obj, function)
(axis=axis).get_full_data(),
getattr(np, function)(a, axis=axis), decimal=4)
......@@ -132,9 +132,9 @@ if HP_DISTRIBUTION_STRATEGIES != []:
unary_operations = ['pos', 'neg', 'abs', 'real', 'imag', 'nanmin', 'amin',
'nanmax', 'amax', 'median', 'mean', 'std', 'var', 'argmin',
'argmin_flat', 'argmax', 'argmax_flat', 'conjugate', 'sum',
'prod', 'unique', 'copy', 'copy_empty', 'isnan', 'isinf',
'isfinite', 'nan_to_num', 'all', 'any', 'None']
'argmin_nonflat', 'argmax', 'argmax_nonflat', 'conjugate',
'sum', 'prod', 'unique', 'copy', 'copy_empty', 'isnan',
'isinf', 'isfinite', 'nan_to_num', 'all', 'any', 'None']
binary_operations = ['add', 'radd', 'iadd', 'sub', 'rsub', 'isub', 'mul',
'rmul', 'imul', 'div', 'rdiv', 'idiv', 'pow', 'rpow',
......@@ -1358,15 +1358,21 @@ class Test_axis(unittest.TestCase):
@parameterized.expand(
itertools.product(point_like_spaces, [8, 16],
['sum', 'prod', 'mean', 'var', 'std', 'median', 'all',
'any', 'nanmin', 'nanmax'], [None, (0,)],
'any', 'amin', 'nanmin', 'argmin',
'argmin_nonflat', 'amax', 'nanmax', 'argmax',
'argmax_nonflat'],
[None, (0,)],
DATAMODELS['point_space']),
testcase_func_name=custom_name_func)
def test_binary_operations(self, name, num, op, axis, datamodel):
s = generate_space_with_size(name, np.prod(num), datamodel=datamodel)
d = generate_data(s)
a = d.get_full_data()
assert_almost_equal(s.unary_operation(d, op, axis=axis),
getattr(np, op)(a, axis=axis), decimal=4)
if name in ['rg_space']:
assert_almost_equal(s.unary_operation(d, op, axis=(0, 1)),
getattr(np, op)(a, axis=(0, 1)), decimal=4)
\ No newline at end of file
if op in ['argmin', 'argmin_nonflat', 'argmax', 'argmax_nonflat']:
assert_raises(NotImplementedError)
else:
assert_almost_equal(s.unary_operation(d, op, axis=axis),
getattr(np, op)(a, axis=axis), decimal=4)
if name in ['rg_space']:
assert_almost_equal(s.unary_operation(d, op, axis=(0, 1)),
getattr(np, op)(a, axis=(0, 1)), decimal=4)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment