Commit fb7e53de authored by Theo Steininger's avatar Theo Steininger

Added functionality for empty-shape d2o's.

parent b8f6f2d9
...@@ -1226,6 +1226,9 @@ class distributed_data_object(Versionable, object): ...@@ -1226,6 +1226,9 @@ class distributed_data_object(Versionable, object):
if axis is not None: if axis is not None:
raise NotImplementedError("ERROR: argmin doesn't support axis " raise NotImplementedError("ERROR: argmin doesn't support axis "
"keyword") "keyword")
if self.shape == ():
return 0
if 0 in self.local_shape: if 0 in self.local_shape:
local_argmin = np.nan local_argmin = np.nan
local_argmin_value = np.nan local_argmin_value = np.nan
...@@ -1260,6 +1263,9 @@ class distributed_data_object(Versionable, object): ...@@ -1260,6 +1263,9 @@ class distributed_data_object(Versionable, object):
if axis is not None: if axis is not None:
raise NotImplementedError("ERROR: argmax doesn't support axis " raise NotImplementedError("ERROR: argmax doesn't support axis "
"keyword") "keyword")
if self.shape == ():
return 0
if 0 in self.local_shape: if 0 in self.local_shape:
local_argmax = np.nan local_argmax = np.nan
local_argmax_value = -np.inf local_argmax_value = -np.inf
...@@ -1291,7 +1297,8 @@ class distributed_data_object(Versionable, object): ...@@ -1291,7 +1297,8 @@ class distributed_data_object(Versionable, object):
See Also: See Also:
argmin, argmax, argmax_nonflat argmin, argmax, argmax_nonflat
""" """
if self.shape == ():
return (0,)
return np.unravel_index(self.argmin(axis=axis), self.shape) return np.unravel_index(self.argmin(axis=axis), self.shape)
def argmax_nonflat(self, axis=None): def argmax_nonflat(self, axis=None):
...@@ -1300,6 +1307,8 @@ class distributed_data_object(Versionable, object): ...@@ -1300,6 +1307,8 @@ class distributed_data_object(Versionable, object):
See Also: See Also:
argmin, argmax, argmin_nonflat argmin, argmax, argmin_nonflat
""" """
if self.shape == ():
return (0,)
return np.unravel_index(self.argmax(axis=axis), self.shape) return np.unravel_index(self.argmax(axis=axis), self.shape)
def conjugate(self): def conjugate(self):
......
...@@ -81,7 +81,7 @@ class _distributor_factory(object): ...@@ -81,7 +81,7 @@ class _distributor_factory(object):
if expensive_checks: if expensive_checks:
# Check that all nodes got the same distribution_strategy # Check that all nodes got the same distribution_strategy
strat_list = comm.allgather(distribution_strategy) strat_list = comm.allgather(distribution_strategy)
if all(x == strat_list[0] for x in strat_list) == False: if not all(x == strat_list[0] for x in strat_list):
raise ValueError(about_cstring( raise ValueError(about_cstring(
"ERROR: The distribution-strategy must be the same on " + "ERROR: The distribution-strategy must be the same on " +
"all nodes!")) "all nodes!"))
...@@ -135,7 +135,7 @@ class _distributor_factory(object): ...@@ -135,7 +135,7 @@ class _distributor_factory(object):
dtype = np.dtype(dtype) dtype = np.dtype(dtype)
if expensive_checks: if expensive_checks:
dtype_list = comm.allgather(dtype) dtype_list = comm.allgather(dtype)
if all(x == dtype_list[0] for x in dtype_list) == False: if not all(x == dtype_list[0] for x in dtype_list):
raise ValueError(about_cstring( raise ValueError(about_cstring(
"ERROR: The given dtype must be the same on all nodes!")) "ERROR: The given dtype must be the same on all nodes!"))
return_dict['dtype'] = dtype return_dict['dtype'] = dtype
...@@ -145,17 +145,19 @@ class _distributor_factory(object): ...@@ -145,17 +145,19 @@ class _distributor_factory(object):
if distribution_strategy in STRATEGIES['global']: if distribution_strategy in STRATEGIES['global']:
if dset is not None: if dset is not None:
global_shape = dset.shape global_shape = dset.shape
elif global_data is not None and np.isscalar(global_data) == False: elif global_data is not None and not np.isscalar(global_data):
global_shape = global_data.shape global_shape = global_data.shape
elif global_shape is not None: elif global_shape is not None:
global_shape = tuple(global_shape) global_shape = tuple(global_shape)
elif global_data is not None:
global_shape = ()
else: else:
raise ValueError(about_cstring( raise ValueError(about_cstring(
"ERROR: Neither non-0-dimensional global_data nor " + "ERROR: Neither global_data nor " +
"global_shape nor hdf5 file supplied!")) "global_shape nor hdf5 file supplied!"))
if global_shape == (): # if global_shape == ():
raise ValueError(about_cstring( # raise ValueError(about_cstring(
"ERROR: global_shape == () is not a valid shape!")) # "ERROR: global_shape == () is not a valid shape!"))
if expensive_checks: if expensive_checks:
global_shape_list = comm.allgather(global_shape) global_shape_list = comm.allgather(global_shape)
...@@ -170,7 +172,7 @@ class _distributor_factory(object): ...@@ -170,7 +172,7 @@ class _distributor_factory(object):
elif distribution_strategy in ['freeform']: elif distribution_strategy in ['freeform']:
if isinstance(global_data, distributed_data_object): if isinstance(global_data, distributed_data_object):
local_shape = global_data.local_shape local_shape = global_data.local_shape
elif local_data is not None and np.isscalar(local_data) == False: elif local_data is not None and not np.isscalar(local_data):
local_shape = local_data.shape local_shape = local_data.shape
elif local_shape is not None: elif local_shape is not None:
local_shape = tuple(local_shape) local_shape = tuple(local_shape)
...@@ -240,6 +242,11 @@ class _distributor_factory(object): ...@@ -240,6 +242,11 @@ class _distributor_factory(object):
comm=comm, comm=comm,
**kwargs) **kwargs)
if parsed_kwargs.get('global_shape') == ():
distribution_strategy = 'not'
about_infos_cprint("WARNING: Distribution strategy was set to "
"'not' because of global_shape == ()")
hashed_kwargs = self.hash_arguments(distribution_strategy, hashed_kwargs = self.hash_arguments(distribution_strategy,
**parsed_kwargs) **parsed_kwargs)
# check if the distributors has already been produced in the past # check if the distributors has already been produced in the past
...@@ -441,6 +448,8 @@ class distributor(object): ...@@ -441,6 +448,8 @@ class distributor(object):
i += 1 i += 1
def bincount(self, obj, length, weights=None, axis=None): def bincount(self, obj, length, weights=None, axis=None):
if obj.shape == ():
raise ValueError("object of too small depth for desired array")
data = obj.get_local_data(copy=False) data = obj.get_local_data(copy=False)
# this implementation fits all distribution strategies where the # this implementation fits all distribution strategies where the
# axes of the global array correspond to the axes of the local data # axes of the global array correspond to the axes of the local data
...@@ -2240,7 +2249,7 @@ class _not_distributor(distributor): ...@@ -2240,7 +2249,7 @@ class _not_distributor(distributor):
if isinstance(data_object, distributed_data_object): if isinstance(data_object, distributed_data_object):
result_data = data_object.get_full_data() result_data = data_object.get_full_data()
else: else:
result_data = np.array(data_object)[:] result_data = np.array(data_object)
try: try:
result_data = result_data.reshape(self.global_shape) result_data = result_data.reshape(self.global_shape)
except ValueError: except ValueError:
......
...@@ -115,7 +115,11 @@ def custom_name_func(testcase_func, param_num, param): ...@@ -115,7 +115,11 @@ def custom_name_func(testcase_func, param_num, param):
def generate_data(global_shape, dtype, distribution_strategy, def generate_data(global_shape, dtype, distribution_strategy,
strictly_positive=False): strictly_positive=False):
if distribution_strategy in global_distribution_strategies: if global_shape == ():
obj = distributed_data_object(global_shape=(), global_data=42.,
distribution_strategy='not')
global_a = np.array(42)
elif distribution_strategy in global_distribution_strategies:
a = np.arange(np.prod(global_shape)) a = np.arange(np.prod(global_shape))
a -= np.prod(global_shape) // 2 a -= np.prod(global_shape) // 2
...@@ -250,6 +254,10 @@ class Test_Globaltype_Initialization(unittest.TestCase): ...@@ -250,6 +254,10 @@ class Test_Globaltype_Initialization(unittest.TestCase):
(2, 2), np.dtype('int')], (2, 2), np.dtype('int')],
[None, (10, 10), None, [None, (10, 10), None,
(10, 10), np.dtype('float64')], (10, 10), np.dtype('float64')],
[1., None, None,
(), np.dtype('float64')],
[None, (), None,
(), np.dtype('float64')],
], global_distribution_strategies), ], global_distribution_strategies),
testcase_func_name=custom_name_func) testcase_func_name=custom_name_func)
def test_special_init_cases(self, def test_special_init_cases(self,
...@@ -269,7 +277,7 @@ class Test_Globaltype_Initialization(unittest.TestCase): ...@@ -269,7 +277,7 @@ class Test_Globaltype_Initialization(unittest.TestCase):
############################################################################### ###############################################################################
if FOUND['h5py'] == True: if FOUND['h5py']:
@parameterized.expand(itertools.product(hdf5_test_paths, @parameterized.expand(itertools.product(hdf5_test_paths,
hdf5_distribution_strategies), hdf5_distribution_strategies),
testcase_func_name=custom_name_func) testcase_func_name=custom_name_func)
...@@ -289,8 +297,6 @@ class Test_Globaltype_Initialization(unittest.TestCase): ...@@ -289,8 +297,6 @@ class Test_Globaltype_Initialization(unittest.TestCase):
itertools.product( itertools.product(
[(None, None, None, None, None), [(None, None, None, None, None),
(None, None, np.int_, None, None), (None, None, np.int_, None, None),
(None, (), np.dtype('int'), None, None),
(1, None, None, None, None),
(None, None, None, np.array([1, 2, 3]), (3,)), (None, None, None, np.array([1, 2, 3]), (3,)),
(None, None, np.int_, None, (3,))], (None, None, np.int_, None, (3,))],
global_distribution_strategies), global_distribution_strategies),
...@@ -1507,7 +1513,7 @@ class Test_contractions(unittest.TestCase): ...@@ -1507,7 +1513,7 @@ class Test_contractions(unittest.TestCase):
@parameterized.expand( @parameterized.expand(
itertools.product([np.dtype('int'), np.dtype('float'), itertools.product([np.dtype('int'), np.dtype('float'),
np.dtype('complex')], np.dtype('complex')],
[(0,), (4, 4)], [(), (0,), (4, 4)],
all_distribution_strategies), all_distribution_strategies),
testcase_func_name=custom_name_func) testcase_func_name=custom_name_func)
def test_vdot(self, dtype, global_shape, distribution_strategy): def test_vdot(self, dtype, global_shape, distribution_strategy):
...@@ -1669,7 +1675,7 @@ class Test_special_methods(unittest.TestCase): ...@@ -1669,7 +1675,7 @@ class Test_special_methods(unittest.TestCase):
############################################################################### ###############################################################################
@parameterized.expand( @parameterized.expand(
itertools.product([(4,), (8, 8), (0, 4), (4, 0, 8)], itertools.product([(), (4,), (8, 8), (0, 4), (4, 0, 8)],
all_distribution_strategies), all_distribution_strategies),
testcase_func_name=custom_name_func) testcase_func_name=custom_name_func)
def test_flatten(self, global_shape, distribution_strategy): def test_flatten(self, global_shape, distribution_strategy):
...@@ -1678,7 +1684,7 @@ class Test_special_methods(unittest.TestCase): ...@@ -1678,7 +1684,7 @@ class Test_special_methods(unittest.TestCase):
distribution_strategy) distribution_strategy)
assert_equal(obj.flatten().get_full_data(), a.flatten()) assert_equal(obj.flatten().get_full_data(), a.flatten())
p = obj.flatten(inplace=True) p = obj.flatten(inplace=True)
if np.prod(global_shape) != 0: if np.prod(global_shape) != 0 and global_shape != ():
p[0] = 2222 p[0] = 2222
assert_equal(obj[(0,) * len(global_shape)], 2222) assert_equal(obj[(0,) * len(global_shape)], 2222)
...@@ -1729,7 +1735,7 @@ class Test_special_methods(unittest.TestCase): ...@@ -1729,7 +1735,7 @@ class Test_special_methods(unittest.TestCase):
############################################################################### ###############################################################################
if FOUND['h5py'] == True: if FOUND['h5py']:
class Test_load_save(unittest.TestCase): class Test_load_save(unittest.TestCase):
@parameterized.expand( @parameterized.expand(
...@@ -1747,7 +1753,7 @@ if FOUND['h5py'] == True: ...@@ -1747,7 +1753,7 @@ if FOUND['h5py'] == True:
path = os.path.join(tempfile.gettempdir(), path = os.path.join(tempfile.gettempdir(),
'temp_hdf5_file.hdf5') 'temp_hdf5_file.hdf5')
if size > 1 and FOUND['h5py_parallel'] == False: if size > 1 and not FOUND['h5py_parallel']:
assert_raises(RuntimeError, assert_raises(RuntimeError,
lambda: obj.save(alias=alias, path=path)) lambda: obj.save(alias=alias, path=path))
else: else:
...@@ -1812,7 +1818,7 @@ class Test_axis(unittest.TestCase): ...@@ -1812,7 +1818,7 @@ class Test_axis(unittest.TestCase):
'min', 'amin', 'nanmin', 'argmin', 'max', 'amax', 'min', 'amin', 'nanmin', 'argmin', 'max', 'amax',
'nanmax', 'argmax'], 'nanmax', 'argmax'],
all_datatypes[1:], all_datatypes[1:],
[(1, ), (2, 3)], [(), (1,), (2, 3)],
all_distribution_strategies, all_distribution_strategies,
[None, 0, (1, ), (0, 1)]), [None, 0, (1, ), (0, 1)]),
testcase_func_name=custom_name_func) testcase_func_name=custom_name_func)
...@@ -1826,15 +1832,20 @@ class Test_axis(unittest.TestCase): ...@@ -1826,15 +1832,20 @@ class Test_axis(unittest.TestCase):
assert_raises(NotImplementedError, lambda: getattr(obj, function) assert_raises(NotImplementedError, lambda: getattr(obj, function)
(axis=axis)) (axis=axis))
else: else:
if global_shape != (1,): if global_shape == (2, 3):
assert_almost_equal(getattr(obj, function)(axis=axis), assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis), getattr(np, function)(a, axis=axis),
decimal=4) decimal=4)
else: elif global_shape == (1,):
if axis in [None, 0, (0,)]: if axis in [None, 0, (0,)]:
assert_almost_equal(getattr(obj, function)(axis=axis), assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis), getattr(np, function)(a, axis=axis),
decimal=4) decimal=4)
else:
if axis in [None]:
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
@parameterized.expand( @parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any', itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
...@@ -1891,7 +1902,7 @@ class Test_axis(unittest.TestCase): ...@@ -1891,7 +1902,7 @@ class Test_axis(unittest.TestCase):
itertools.product([('argmin_nonflat', 'argmin'), itertools.product([('argmin_nonflat', 'argmin'),
('argmax_nonflat', 'argmax')], ('argmax_nonflat', 'argmax')],
all_datatypes[1:], all_datatypes[1:],
[(0,), (1,), (4, 4, 3), (4, 0, 3)], [(), (0,), (1,), (4, 4, 3), (4, 0, 3)],
all_distribution_strategies, all_distribution_strategies,
[None, (1, ), (1, 2)]), [None, (1, ), (1, 2)]),
testcase_func_name=custom_name_func) testcase_func_name=custom_name_func)
...@@ -1900,25 +1911,24 @@ class Test_axis(unittest.TestCase): ...@@ -1900,25 +1911,24 @@ class Test_axis(unittest.TestCase):
(a, obj) = generate_data(global_shape, dtype, (a, obj) = generate_data(global_shape, dtype,
distribution_strategy, distribution_strategy,
strictly_positive=True) strictly_positive=True)
print (a, obj)
if 0 in global_shape: if 0 in global_shape:
assert_raises(ValueError, assert_raises(ValueError,
lambda: getattr(obj, function_pair[0])(axis=axis)) lambda: getattr(obj, function_pair[0])(axis=axis))
else: else:
if axis is not None: if axis is not None and global_shape != ():
assert_raises(NotImplementedError, assert_raises(NotImplementedError,
lambda: getattr(obj, lambda: getattr(obj,
function_pair[0])(axis=axis)) function_pair[0])(axis=axis))
else: else:
if global_shape != (0,) and global_shape != (1,): if len(global_shape) > 1:
assert_almost_equal( assert_almost_equal(
getattr(obj, function_pair[0])(axis=axis), getattr(obj, function_pair[0])(axis=axis),
np.unravel_index(getattr(np, function_pair[1]) np.unravel_index(getattr(np, function_pair[1])
(a, axis=axis), (a, axis=axis),
dims=global_shape), dims=global_shape),
decimal=4) decimal=4)
else: elif len(global_shape) == 1:
assert_almost_equal(getattr(obj, function_pair[0]) assert_almost_equal(getattr(obj, function_pair[0])
(axis=axis), (axis=axis),
np.unravel_index( np.unravel_index(
...@@ -1926,6 +1936,9 @@ class Test_axis(unittest.TestCase): ...@@ -1926,6 +1936,9 @@ class Test_axis(unittest.TestCase):
(a, axis=axis), (a, axis=axis),
dims=global_shape), dims=global_shape),
decimal=4) decimal=4)
else:
assert_almost_equal(getattr(obj, function_pair[0])
(axis=axis), (0,))
class Test_arange(unittest.TestCase): class Test_arange(unittest.TestCase):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment