Commit 154a163f authored by Theo Steininger

Merge branch 'minor-fixes-for-axis' into 'add_axis_keyword_to_d2o'

Minor fixes for axis



See merge request !10
parents f5a9d3ac 5599041c
Pipeline #2029 skipped
# -*- coding: utf-8 -*-
-import numbers
import numpy as np
from nifty import about

@@ -11,7 +10,7 @@ def cast_axis_to_tuple(axis):
    try:
        axis = tuple([int(item) for item in axis])
    except(TypeError):
-        if isinstance(axis, numbers.Number):
+        if np.isscalar(axis):
            axis = (int(axis), )
        else:
            raise TypeError(about._errors.cstring(
...
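A quick note on the `isinstance(axis, numbers.Number)` → `np.isscalar(axis)` switch: the values below only illustrate how `np.isscalar` classifies typical inputs to `cast_axis_to_tuple`; they are not part of the change itself.

```python
import numpy as np

np.isscalar(3)            # True  - plain Python scalars
np.isscalar(np.int64(3))  # True  - NumPy scalar types as well
np.isscalar(np.array(3))  # False - a 0-d ndarray is not a scalar
np.isscalar([0, 1])       # False - axis lists are handled by the tuple() branch above
```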
# -*- coding: utf-8 -*-
-import numbers as numbers
import numpy as np
from nifty.keepers import about,\

@@ -1177,7 +1175,7 @@ class distributed_data_object(object):
    def std(self, axis=None):
        """ Returns the standard deviation of the d2o's elements. """
        var = self.var(axis=axis)
-        if isinstance(var, numbers.Number):
+        if np.isscalar(var):
            return np.sqrt(var)
        else:
            return var.apply_scalar_function(np.sqrt)

@@ -1294,7 +1292,7 @@ class distributed_data_object(object):
        about.warnings.cprint(
            "WARNING: The current implementation of median is very expensive!")
        median = np.median(self.get_full_data(), axis=axis, **kwargs)
-        if isinstance(median, numbers.Number):
+        if np.isscalar(median):
            return median
        else:
            x = self.copy_empty(global_shape=median.shape,

@@ -1303,7 +1301,6 @@ class distributed_data_object(object):
            x.set_local_data(median)
            return x

    def _is_helper(self, function):
        """ _is_helper is used for functions like isreal, isinf, isfinite,...
...
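The `std` and `median` hunks rely on NumPy returning a plain scalar for a full reduction and an ndarray once an axis (or axis tuple) is given; a minimal sketch of that behaviour:

```python
import numpy as np

a = np.arange(12.0).reshape(3, 4)

np.isscalar(a.std())           # True  - reducing over all axes yields one number
np.isscalar(a.std(axis=(1,)))  # False - reducing along selected axes yields an array
```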
@@ -517,14 +517,32 @@ class _slicing_distributor(distributor):
        # check if additional contraction along the first axis must be done
        if axis is None or 0 in axis:
            (mpi_op, bufferQ) = op_translate_dict[function]
-            contracted_local_data = self.comm.allreduce(contracted_local_data,
-                                                        op=mpi_op)
+            # check if allreduce must be used instead of Allreduce
+            use_Uppercase = False
+            if bufferQ and isinstance(contracted_local_data, np.ndarray):
+                # MPI.MAX and MPI.MIN do not support complex data types
+                if not np.issubdtype(contracted_local_data.dtype,
+                                     np.complexfloating):
+                    use_Uppercase = True
+            if use_Uppercase:
+                global_contracted_local_data = np.empty_like(
+                    contracted_local_data)
+                new_mpi_dtype = self._my_dtype_converter.to_mpi(new_dtype)
+                self.comm.Allreduce([contracted_local_data,
+                                     new_mpi_dtype],
+                                    [global_contracted_local_data,
+                                     new_mpi_dtype],
+                                    op=mpi_op)
+            else:
+                global_contracted_local_data = self.comm.allreduce(
+                    contracted_local_data, op=mpi_op)
            new_dist_strategy = 'not'
        else:
            new_dist_strategy = parent.distribution_strategy
+            global_contracted_local_data = contracted_local_data
        if new_shape == ():
-            result = contracted_local_data
+            result = global_contracted_local_data
        else:
            # try to store the result in a distributed_data_object with the
            # distribution_strategy as parent

@@ -538,12 +556,12 @@ class _slicing_distributor(distributor):
            # Contracting (4, 4) to (4,).
            # (4, 4) was distributed (1, 4)...(1, 4)
            # (4, ) is not distributed like (1,)...(1,) but like (2,)(2,)()()!
-            if result.local_shape != contracted_local_data.shape:
+            if result.local_shape != global_contracted_local_data.shape:
                result = parent.copy_empty(
-                    local_shape=contracted_local_data.shape,
+                    local_shape=global_contracted_local_data.shape,
                    dtype=new_dtype,
                    distribution_strategy='freeform')
-            result.set_local_data(contracted_local_data, copy=False)
+            result.set_local_data(global_contracted_local_data, copy=False)
        return result
...
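The distributor now prefers the buffer-based `Allreduce` over the object-based `allreduce` whenever the contracted data is a non-complex ndarray. A rough mpi4py sketch of the two call styles; the explicit `MPI.DOUBLE` here stands in for the `_my_dtype_converter` lookup used above:

```python
from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
local = np.arange(4, dtype=np.float64) + comm.rank

# Object interface: pickles the send object; works for any type but is slower.
result_obj = comm.allreduce(local, op=MPI.SUM)

# Buffer interface: operates on raw memory and needs an explicit MPI datatype.
# This is the path taken when bufferQ is set and the dtype is not complex
# (MPI.MAX / MPI.MIN are not defined for complex buffers).
result_buf = np.empty_like(local)
comm.Allreduce([local, MPI.DOUBLE], [result_buf, MPI.DOUBLE], op=MPI.SUM)
```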
@@ -8,10 +8,10 @@ from nifty.keepers import global_configuration as gc,\
MPI = gdi[gc['mpi_module']]

custom_NANMIN = MPI.Op.Create(lambda x, y, datatype:
-                              np.nanmin(np.vstack(x, y), axis=0))
+                              np.nanmin(np.vstack((x, y)), axis=0))
custom_NANMAX = MPI.Op.Create(lambda x, y, datatype:
-                              np.nanmax(np.vstack(x, y), axis=0))
+                              np.nanmax(np.vstack((x, y)), axis=0))
custom_UNIQUE = MPI.Op.Create(lambda x, y, datatype:
                              np.unique(np.concatenate([x, y])))
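The `np.vstack(x, y)` → `np.vstack((x, y))` fix matters because `vstack` takes a single sequence of arrays; passing the two arrays as separate positional arguments raises a `TypeError`. A quick check of the corrected reduction step (sample values only):

```python
import numpy as np

x = np.array([1.0, np.nan, 3.0])
y = np.array([2.0, 5.0, np.nan])

stacked = np.vstack((x, y))    # shape (2, 3); np.vstack(x, y) would be a TypeError
np.nanmin(stacked, axis=0)     # array([1., 5., 3.]) - NaNs are ignored elementwise
```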
@@ -24,8 +24,8 @@ op_translate_dict[np.sum] = (MPI.SUM, True)
op_translate_dict[np.prod] = (MPI.PROD, True)
op_translate_dict[np.amin] = (MPI.MIN, True)
op_translate_dict[np.amax] = (MPI.MAX, True)
-op_translate_dict[np.all] = (MPI.LAND, True)
-op_translate_dict[np.any] = (MPI.LOR, True)
+op_translate_dict[np.all] = (MPI.BAND, True)
+op_translate_dict[np.any] = (MPI.BOR, True)
op_translate_dict[np.nanmin] = (custom_NANMIN, False)
op_translate_dict[np.nanmax] = (custom_NANMAX, False)
op_translate_dict[np.unique] = (custom_UNIQUE, False)
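Switching `np.all`/`np.any` to the bitwise `MPI.BAND`/`MPI.BOR` works because NumPy booleans are single bytes holding 0 or 1, so a bitwise reduction across processes reproduces the logical one. A small local consistency check (illustrative only):

```python
import numpy as np

a = np.array([True, True, False])
b = np.array([True, False, False])

# For 0/1-valued boolean arrays, bitwise AND/OR equals logical AND/OR,
# which is what MPI.BAND / MPI.BOR compute elementwise across ranks.
assert np.array_equal(a & b, np.logical_and(a, b))
assert np.array_equal(a | b, np.logical_or(a, b))
```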
# -*- coding: utf-8 -*-
-import numbers
import copy
import numpy as np

@@ -89,7 +88,7 @@ class Intracomm(Comm):
        return recvbuf

    def allreduce(self, sendobj, op=SUM, **kwargs):
-        if isinstance(sendobj, numbers.Number):
+        if np.isscalar(sendobj):
            return sendobj
        return copy.copy(sendobj)
...
@@ -17,6 +17,8 @@ import nifty
from nifty.d2o import distributed_data_object,\
    STRATEGIES

+from distutils.version import LooseVersion as lv
+
FOUND = {}
try:
    import h5py

@@ -1049,8 +1051,8 @@ class Test_set_data_via_injection(unittest.TestCase):
            all_distribution_strategies
        ), testcase_func_name=custom_name_func)
    def test_set_data_via_injection(self, (global_shape_1, slice_tuple_1,
                                           global_shape_2, slice_tuple_2),
                                    distribution_strategy):
        dtype = np.dtype('float')
        (a, obj) = generate_data(global_shape_1, dtype,
                                 distribution_strategy)

@@ -1059,8 +1061,8 @@ class Test_set_data_via_injection(unittest.TestCase):
                                 distribution_strategy)
        obj.set_data(to_key=slice_tuple_1,
                     data=p,
                     from_key=slice_tuple_2)
        a[slice_tuple_1] = b[slice_tuple_2]
        assert_equal(obj.get_full_data(), a)
@@ -1601,9 +1603,9 @@ class Test_comparisons(unittest.TestCase):
class Test_special_methods(unittest.TestCase):
    @parameterized.expand(
        itertools.product(all_distribution_strategies,
                          all_distribution_strategies),
        testcase_func_name=custom_name_func)
    def test_bincount(self, distribution_strategy_1, distribution_strategy_2):
        global_shape = (10,)
        dtype = np.dtype('int')
@@ -1742,12 +1744,11 @@ if FOUND['h5py'] == True:
class Test_axis(unittest.TestCase):
    @parameterized.expand(
-        itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'median', 'all',
-                           'any', 'min', 'amin', 'nanmin', 'argmin',
-                           'argmin_nonflat', 'max', 'amax', 'nanmax',
-                           'argmax', 'argmax_nonflat'],
+        itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
+                           'min', 'amin', 'nanmin', 'argmin', 'argmin_nonflat',
+                           'max', 'amax', 'nanmax', 'argmax', 'argmax_nonflat'],
                          all_datatypes[1:],
-                          [(0,), (1,), (6, 6), (5, 5, 5)],
+                          [(0,), (1,), (6, 6), (4, 4, 3)],
                          all_distribution_strategies,
                          [None, 0, (1, ), (0, 1)]),
        testcase_func_name=custom_name_func)

@@ -1774,12 +1775,11 @@ class Test_axis(unittest.TestCase):
            decimal=4)

    @parameterized.expand(
-        itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'median', 'all',
-                           'any', 'min', 'amin', 'nanmin', 'argmin',
-                           'argmin_nonflat', 'max', 'amax', 'nanmax',
-                           'argmax', 'argmax_nonflat'],
+        itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
+                           'min', 'amin', 'nanmin', 'argmin', 'argmin_nonflat',
+                           'max', 'amax', 'nanmax', 'argmax', 'argmax_nonflat'],
                          all_datatypes[1:],
-                          [(5, 5, 5), (4, 0, 3)],
+                          [(4, 4, 3), (4, 0, 3)],
                          all_distribution_strategies, [(0, 1), (1, 2)]),
        testcase_func_name=custom_name_func)
    def test_axis_with_functions_for_many_dimentions(self, function, dtype,

@@ -1798,4 +1798,27 @@ class Test_axis(unittest.TestCase):
        else:
            assert_almost_equal(getattr(obj, function)
                                (axis=axis).get_full_data(),
-                                getattr(np, function)(a, axis=axis), decimal=4)
+                                getattr(np, function)(a, axis=axis),
+                                decimal=4)
+
+    @parameterized.expand(
+        itertools.product(all_datatypes[1:],
+                          [(0,), (1,), (4, 4, 3), (4, 0, 3)],
+                          all_distribution_strategies,
+                          [None, 0, (1, ), (0, 1), (1, 2)]),
+        testcase_func_name=custom_name_func)
+    def test_axis_for_median(self, dtype, global_shape,
+                             distribution_strategy, axis):
+        (a, obj) = generate_data(global_shape, dtype,
+                                 distribution_strategy,
+                                 strictly_positive=True)
+        if global_shape != (0,) and global_shape != (1,) and lv("1.9.0Z") < \
+                lv(np.__version__):
+            assert_almost_equal(getattr(obj, 'median')(axis=axis),
+                                getattr(np, 'median')(a, axis=axis),
+                                decimal=4)
+        else:
+            if axis is None or axis == 0 or axis == (0,):
+                assert_almost_equal(getattr(obj, 'median')(axis=axis),
+                                    getattr(np, 'median')(a, axis=axis),
+                                    decimal=4)
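The new median test gates the tuple-axis comparison on the NumPy version via `LooseVersion`; the `"1.9.0Z"` trick makes the check pass only for releases strictly newer than 1.9.0. Illustrative comparisons:

```python
from distutils.version import LooseVersion as lv

lv("1.9.0Z") < lv("1.9.0")    # False - 1.9.0 itself does not pass the gate
lv("1.9.0Z") < lv("1.9.2")    # True
lv("1.9.0Z") < lv("1.10.1")   # True
```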
@@ -31,19 +31,19 @@ from nifty.operators.nifty_operators import power_operator
available = []
try:
    from nifty import lm_space
except ImportError:
    pass
else:
    available += ['lm_space']
try:
    from nifty import gl_space
except ImportError:
    pass
else:
    available += ['gl_space']
try:
    from nifty import hp_space
except ImportError:
    pass
else:

@@ -1364,7 +1364,7 @@ class Test_axis(unittest.TestCase):
                          [None, (0,)],
                          DATAMODELS['point_space']),
        testcase_func_name=custom_name_func)
-    def test_binary_operations(self, name, num, op, axis, datamodel):
+    def test_unary_operations(self, name, num, op, axis, datamodel):
        s = generate_space_with_size(name, np.prod(num), datamodel=datamodel)
        d = generate_data(s)
        a = d.get_full_data()

@@ -1375,4 +1375,6 @@ class Test_axis(unittest.TestCase):
                            getattr(np, op)(a, axis=axis), decimal=4)
        if name in ['rg_space']:
            assert_almost_equal(s.unary_operation(d, op, axis=(0, 1)),
                                getattr(np, op)(a, axis=(0, 1)), decimal=4)
\ No newline at end of file
+            assert_almost_equal(s.unary_operation(d, op, axis=(1,)),
+                                getattr(np, op)(a, axis=(1,)), decimal=4)
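The renamed `test_unary_operations` now also exercises a one-element axis tuple; for reference, how NumPy reductions behave with tuple axes (shapes chosen for illustration):

```python
import numpy as np

a = np.arange(24.0).reshape(2, 3, 4)

np.sum(a, axis=(0, 1)).shape   # (4,)   - both leading axes are contracted
np.sum(a, axis=(1,)).shape     # (2, 4) - a one-element tuple behaves like axis=1
np.sum(a, axis=None)           # full contraction down to a single scalar
```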