Commit f5a9d3ac authored by theos's avatar theos
Browse files

Merge branch 'add_axis_keyword_to_d2o' of gitlab.mpcdf.mpg.de:ift/NIFTy into...

Merge branch 'add_axis_keyword_to_d2o' of gitlab.mpcdf.mpg.de:ift/NIFTy into add_axis_keyword_to_d2o
parents f38220f4 8cf39e8f
Pipeline #1912 skipped
# -*- coding: utf-8 -*-
import numbers as numbers
import numpy as np
......@@ -1175,9 +1176,13 @@ class distributed_data_object(object):
def std(self, axis=None):
""" Returns the standard deviation of the d2o's elements. """
return np.sqrt(self.var(axis=axis))
var = self.var(axis=axis)
if isinstance(var, numbers.Number):
return np.sqrt(var)
else:
return var.apply_scalar_function(np.sqrt)
def argmin(self):
def argmin(self, axis=None):
""" Returns the (flat) index of the d2o's smallest value.
See Also:
......@@ -1187,6 +1192,9 @@ class distributed_data_object(object):
if 0 in self.shape:
raise ValueError(
"ERROR: attempt to get argmin of an empty object")
if axis is not None:
raise NotImplementedError("ERROR: argmin doesn't support axis "
"keyword")
if 0 in self.local_shape:
local_argmin = np.nan
local_argmin_value = np.nan
......@@ -1208,7 +1216,7 @@ class distributed_data_object(object):
order=['value', 'index'])
return np.int(local_argmin_list[0][1])
def argmax(self):
def argmax(self, axis=None):
""" Returns the (flat) index of the d2o's biggest value.
See Also:
......@@ -1218,6 +1226,9 @@ class distributed_data_object(object):
if 0 in self.shape:
raise ValueError(
"ERROR: attempt to get argmax of an empty object")
if axis is not None:
raise NotImplementedError("ERROR: argmax doesn't support axis "
"keyword")
if 0 in self.local_shape:
local_argmax = np.nan
local_argmax_value = np.nan
......@@ -1238,23 +1249,22 @@ class distributed_data_object(object):
order=['value', 'index'])
return np.int(local_argmax_list[0][1])
def argmin_nonflat(self):
def argmin_nonflat(self, axis=None):
""" Returns the unraveld index of the d2o's smallest value.
See Also:
argmin, argmax, argmax_nonflat
"""
return np.unravel_index(self.argmin(), self.shape)
return np.unravel_index(self.argmin(axis=axis), self.shape)
def argmax_nonflat(self):
def argmax_nonflat(self, axis=None):
""" Returns the unraveld index of the d2o's biggest value.
See Also:
argmin, argmax, argmin_nonflat
"""
return np.unravel_index(self.argmax(), self.shape)
return np.unravel_index(self.argmax(axis=axis), self.shape)
def conjugate(self):
""" Returns the element-wise complex conjugate. """
......@@ -1284,7 +1294,15 @@ class distributed_data_object(object):
about.warnings.cprint(
"WARNING: The current implementation of median is very expensive!")
median = np.median(self.get_full_data(), axis=axis, **kwargs)
return median
if isinstance(median, numbers.Number):
return median
else:
x = self.copy_empty(global_shape=median.shape,
dtype=median.dtype,
distribution_strategy='not')
x.set_local_data(median)
return x
def _is_helper(self, function):
""" _is_helper is used for functions like isreal, isinf, isfinite,...
......
......@@ -892,7 +892,7 @@ class point_space(space):
def apply_scalar_function(self, x, function, inplace=False):
return x.apply_scalar_function(function, inplace=inplace)
def unary_operation(self, x, op='None', **kwargs):
def unary_operation(self, x, op='None', axis=None, **kwargs):
"""
x must be a numpy array which is compatible with the space!
Valid operations are
......@@ -903,21 +903,21 @@ class point_space(space):
'abs': lambda y: getattr(y, '__abs__')(),
'real': lambda y: getattr(y, 'real'),
'imag': lambda y: getattr(y, 'imag'),
'nanmin': lambda y: getattr(y, 'nanmin')(),
'amin': lambda y: getattr(y, 'amin')(),
'nanmax': lambda y: getattr(y, 'nanmax')(),
'amax': lambda y: getattr(y, 'amax')(),
'median': lambda y: getattr(y, 'median')(),
'mean': lambda y: getattr(y, 'mean')(),
'std': lambda y: getattr(y, 'std')(),
'var': lambda y: getattr(y, 'var')(),
'argmin': lambda y: getattr(y, 'argmin_nonflat')(),
'argmin_flat': lambda y: getattr(y, 'argmin')(),
'argmax': lambda y: getattr(y, 'argmax_nonflat')(),
'argmax_flat': lambda y: getattr(y, 'argmax')(),
'nanmin': lambda y: getattr(y, 'nanmin')(axis=axis),
'amin': lambda y: getattr(y, 'amin')(axis=axis),
'nanmax': lambda y: getattr(y, 'nanmax')(axis=axis),
'amax': lambda y: getattr(y, 'amax')(axis=axis),
'median': lambda y: getattr(y, 'median')(axis=axis),
'mean': lambda y: getattr(y, 'mean')(axis=axis),
'std': lambda y: getattr(y, 'std')(axis=axis),
'var': lambda y: getattr(y, 'var')(axis=axis),
'argmin_nonflat': lambda y: getattr(y, 'argmin_nonflat')(axis=axis),
'argmin': lambda y: getattr(y, 'argmin')(axis=axis),
'argmax_nonflat': lambda y: getattr(y, 'argmax_nonflat')(axis=axis),
'argmax': lambda y: getattr(y, 'argmax')(axis=axis),
'conjugate': lambda y: getattr(y, 'conjugate')(),
'sum': lambda y: getattr(y, 'sum')(),
'prod': lambda y: getattr(y, 'prod')(),
'sum': lambda y: getattr(y, 'sum')(axis=axis),
'prod': lambda y: getattr(y, 'prod')(axis=axis),
'unique': lambda y: getattr(y, 'unique')(),
'copy': lambda y: getattr(y, 'copy')(),
'copy_empty': lambda y: getattr(y, 'copy_empty')(),
......@@ -925,8 +925,8 @@ class point_space(space):
'isinf': lambda y: getattr(y, 'isinf')(),
'isfinite': lambda y: getattr(y, 'isfinite')(),
'nan_to_num': lambda y: getattr(y, 'nan_to_num')(),
'all': lambda y: getattr(y, 'all')(),
'any': lambda y: getattr(y, 'any')(),
'all': lambda y: getattr(y, 'all')(axis=axis),
'any': lambda y: getattr(y, 'any')(axis=axis),
'None': lambda y: y}
return translation[op](x, **kwargs)
......@@ -2747,10 +2747,10 @@ class field(object):
"""
if split:
return self._unary_helper(self.get_val(), op='argmin',
return self._unary_helper(self.get_val(), op='argmin_nonflat',
**kwargs)
else:
return self._unary_helper(self.get_val(), op='argmin_flat',
return self._unary_helper(self.get_val(), op='argmin',
**kwargs)
def argmax(self, split=True, **kwargs):
......@@ -2776,10 +2776,10 @@ class field(object):
"""
if split:
return self._unary_helper(self.get_val(), op='argmax',
return self._unary_helper(self.get_val(), op='argmax_nonflat',
**kwargs)
else:
return self._unary_helper(self.get_val(), op='argmax_flat',
return self._unary_helper(self.get_val(), op='argmax',
**kwargs)
# TODO: Implement the full range of unary and binary operotions
......
......@@ -1738,3 +1738,64 @@ if FOUND['h5py'] == True:
# Todo: Assert that data is copied, when copy flag is set
# Todo: Assert that set, get and injection work, if there is different data
# on the nodes
class Test_axis(unittest.TestCase):
@parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'median', 'all',
'any', 'min', 'amin', 'nanmin', 'argmin',
'argmin_nonflat', 'max', 'amax', 'nanmax',
'argmax', 'argmax_nonflat'],
all_datatypes[1:],
[(0,), (1,), (6, 6), (5, 5, 5)],
all_distribution_strategies,
[None, 0, (1, ), (0, 1)]),
testcase_func_name=custom_name_func)
def test_axis_with_functions(self, function, dtype, global_shape,
distribution_strategy, axis):
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy,
strictly_positive=True)
if function in ['argmin', 'argmin_nonflat', 'argmax', 'argmax_nonflat']:
assert_raises(NotImplementedError)
else:
if global_shape != (0,) and global_shape != (1,):
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
else:
if function in ['min', 'amin', 'nanmin', 'max',
'amax', 'nanmax']:
assert_raises(ValueError)
else:
if axis is None or axis == 0 or axis == (0,):
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
@parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'median', 'all',
'any', 'min', 'amin', 'nanmin', 'argmin',
'argmin_nonflat', 'max', 'amax', 'nanmax',
'argmax', 'argmax_nonflat'],
all_datatypes[1:],
[(5, 5, 5), (4, 0, 3)],
all_distribution_strategies, [(0, 1), (1, 2)]),
testcase_func_name=custom_name_func)
def test_axis_with_functions_for_many_dimentions(self, function, dtype,
global_shape,
distribution_strategy,
axis):
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy,
strictly_positive=True)
if function in ['argmin', 'argmin_nonflat', 'argmax', 'argmax_nonflat']:
assert_raises(NotImplementedError)
else:
if function in ['min', 'amin', 'nanmin', 'max', 'amax', 'nanmax']\
and np.prod(global_shape) == 0:
assert_raises(ValueError)
else:
assert_almost_equal(getattr(obj, function)
(axis=axis).get_full_data(),
getattr(np, function)(a, axis=axis), decimal=4)
......@@ -27,6 +27,8 @@ from nifty.nifty_power_indices import power_indices
from nifty.nifty_utilities import _hermitianize_inverter as \
hermitianize_inverter
from nifty.operators.nifty_operators import power_operator
available = []
try:
from nifty import lm_space
......@@ -130,9 +132,9 @@ if HP_DISTRIBUTION_STRATEGIES != []:
unary_operations = ['pos', 'neg', 'abs', 'real', 'imag', 'nanmin', 'amin',
'nanmax', 'amax', 'median', 'mean', 'std', 'var', 'argmin',
'argmin_flat', 'argmax', 'argmax_flat', 'conjugate', 'sum',
'prod', 'unique', 'copy', 'copy_empty', 'isnan', 'isinf',
'isfinite', 'nan_to_num', 'all', 'any', 'None']
'argmin_nonflat', 'argmax', 'argmax_nonflat', 'conjugate',
'sum', 'prod', 'unique', 'copy', 'copy_empty', 'isnan',
'isinf', 'isfinite', 'nan_to_num', 'all', 'any', 'None']
binary_operations = ['add', 'radd', 'iadd', 'sub', 'rsub', 'isub', 'mul',
'rmul', 'imul', 'div', 'rdiv', 'idiv', 'pow', 'rpow',
......@@ -178,6 +180,22 @@ def generate_space(name):
return space_dict[name]
def generate_space_with_size(name, num, datamodel='fftw'):
space_dict = {'space': space(),
'point_space': point_space(num, datamodel=datamodel),
'rg_space': rg_space((num, num), datamodel=datamodel),
}
if 'lm_space' in available:
space_dict['lm_space'] = lm_space(mmax=num, lmax=num,
datamodel=datamodel)
if 'hp_space' in available:
space_dict['hp_space'] = hp_space(num, datamodel=datamodel)
if 'gl_space' in available:
space_dict['gl_space'] = gl_space(nlat=num, nlon=num,
datamodel=datamodel)
return space_dict[name]
def generate_data(space):
a = np.arange(space.get_dim()).reshape(space.get_shape())
data = space.cast(a)
......@@ -1334,4 +1352,27 @@ class Test_Lm_Space(unittest.TestCase):
print all_spaces
print generate_space('rg_space')
\ No newline at end of file
print generate_space('rg_space')
class Test_axis(unittest.TestCase):
@parameterized.expand(
itertools.product(point_like_spaces, [8, 16],
['sum', 'prod', 'mean', 'var', 'std', 'median', 'all',
'any', 'amin', 'nanmin', 'argmin',
'argmin_nonflat', 'amax', 'nanmax', 'argmax',
'argmax_nonflat'],
[None, (0,)],
DATAMODELS['point_space']),
testcase_func_name=custom_name_func)
def test_binary_operations(self, name, num, op, axis, datamodel):
s = generate_space_with_size(name, np.prod(num), datamodel=datamodel)
d = generate_data(s)
a = d.get_full_data()
if op in ['argmin', 'argmin_nonflat', 'argmax', 'argmax_nonflat']:
assert_raises(NotImplementedError)
else:
assert_almost_equal(s.unary_operation(d, op, axis=axis),
getattr(np, op)(a, axis=axis), decimal=4)
if name in ['rg_space']:
assert_almost_equal(s.unary_operation(d, op, axis=(0, 1)),
getattr(np, op)(a, axis=(0, 1)), decimal=4)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment