Commit fb7e53de authored by Theo Steininger's avatar Theo Steininger

Added functionality for empty-shape d2o's.

parent b8f6f2d9
......@@ -1226,6 +1226,9 @@ class distributed_data_object(Versionable, object):
if axis is not None:
raise NotImplementedError("ERROR: argmin doesn't support axis "
"keyword")
if self.shape == ():
return 0
if 0 in self.local_shape:
local_argmin = np.nan
local_argmin_value = np.nan
......@@ -1260,6 +1263,9 @@ class distributed_data_object(Versionable, object):
if axis is not None:
raise NotImplementedError("ERROR: argmax doesn't support axis "
"keyword")
if self.shape == ():
return 0
if 0 in self.local_shape:
local_argmax = np.nan
local_argmax_value = -np.inf
......@@ -1291,7 +1297,8 @@ class distributed_data_object(Versionable, object):
See Also:
argmin, argmax, argmax_nonflat
"""
if self.shape == ():
return (0,)
return np.unravel_index(self.argmin(axis=axis), self.shape)
def argmax_nonflat(self, axis=None):
......@@ -1300,6 +1307,8 @@ class distributed_data_object(Versionable, object):
See Also:
argmin, argmax, argmin_nonflat
"""
if self.shape == ():
return (0,)
return np.unravel_index(self.argmax(axis=axis), self.shape)
def conjugate(self):
......
......@@ -81,7 +81,7 @@ class _distributor_factory(object):
if expensive_checks:
# Check that all nodes got the same distribution_strategy
strat_list = comm.allgather(distribution_strategy)
if all(x == strat_list[0] for x in strat_list) == False:
if not all(x == strat_list[0] for x in strat_list):
raise ValueError(about_cstring(
"ERROR: The distribution-strategy must be the same on " +
"all nodes!"))
......@@ -135,7 +135,7 @@ class _distributor_factory(object):
dtype = np.dtype(dtype)
if expensive_checks:
dtype_list = comm.allgather(dtype)
if all(x == dtype_list[0] for x in dtype_list) == False:
if not all(x == dtype_list[0] for x in dtype_list):
raise ValueError(about_cstring(
"ERROR: The given dtype must be the same on all nodes!"))
return_dict['dtype'] = dtype
......@@ -145,17 +145,19 @@ class _distributor_factory(object):
if distribution_strategy in STRATEGIES['global']:
if dset is not None:
global_shape = dset.shape
elif global_data is not None and np.isscalar(global_data) == False:
elif global_data is not None and not np.isscalar(global_data):
global_shape = global_data.shape
elif global_shape is not None:
global_shape = tuple(global_shape)
elif global_data is not None:
global_shape = ()
else:
raise ValueError(about_cstring(
"ERROR: Neither non-0-dimensional global_data nor " +
"ERROR: Neither global_data nor " +
"global_shape nor hdf5 file supplied!"))
if global_shape == ():
raise ValueError(about_cstring(
"ERROR: global_shape == () is not a valid shape!"))
# if global_shape == ():
# raise ValueError(about_cstring(
# "ERROR: global_shape == () is not a valid shape!"))
if expensive_checks:
global_shape_list = comm.allgather(global_shape)
......@@ -170,7 +172,7 @@ class _distributor_factory(object):
elif distribution_strategy in ['freeform']:
if isinstance(global_data, distributed_data_object):
local_shape = global_data.local_shape
elif local_data is not None and np.isscalar(local_data) == False:
elif local_data is not None and not np.isscalar(local_data):
local_shape = local_data.shape
elif local_shape is not None:
local_shape = tuple(local_shape)
......@@ -240,6 +242,11 @@ class _distributor_factory(object):
comm=comm,
**kwargs)
if parsed_kwargs.get('global_shape') == ():
distribution_strategy = 'not'
about_infos_cprint("WARNING: Distribution strategy was set to "
"'not' because of global_shape == ()")
hashed_kwargs = self.hash_arguments(distribution_strategy,
**parsed_kwargs)
# check if the distributors has already been produced in the past
......@@ -441,6 +448,8 @@ class distributor(object):
i += 1
def bincount(self, obj, length, weights=None, axis=None):
if obj.shape == ():
raise ValueError("object of too small depth for desired array")
data = obj.get_local_data(copy=False)
# this implementation fits all distribution strategies where the
# axes of the global array correspond to the axes of the local data
......@@ -2240,7 +2249,7 @@ class _not_distributor(distributor):
if isinstance(data_object, distributed_data_object):
result_data = data_object.get_full_data()
else:
result_data = np.array(data_object)[:]
result_data = np.array(data_object)
try:
result_data = result_data.reshape(self.global_shape)
except ValueError:
......
......@@ -115,7 +115,11 @@ def custom_name_func(testcase_func, param_num, param):
def generate_data(global_shape, dtype, distribution_strategy,
strictly_positive=False):
if distribution_strategy in global_distribution_strategies:
if global_shape == ():
obj = distributed_data_object(global_shape=(), global_data=42.,
distribution_strategy='not')
global_a = np.array(42)
elif distribution_strategy in global_distribution_strategies:
a = np.arange(np.prod(global_shape))
a -= np.prod(global_shape) // 2
......@@ -250,6 +254,10 @@ class Test_Globaltype_Initialization(unittest.TestCase):
(2, 2), np.dtype('int')],
[None, (10, 10), None,
(10, 10), np.dtype('float64')],
[1., None, None,
(), np.dtype('float64')],
[None, (), None,
(), np.dtype('float64')],
], global_distribution_strategies),
testcase_func_name=custom_name_func)
def test_special_init_cases(self,
......@@ -269,7 +277,7 @@ class Test_Globaltype_Initialization(unittest.TestCase):
###############################################################################
if FOUND['h5py'] == True:
if FOUND['h5py']:
@parameterized.expand(itertools.product(hdf5_test_paths,
hdf5_distribution_strategies),
testcase_func_name=custom_name_func)
......@@ -289,8 +297,6 @@ class Test_Globaltype_Initialization(unittest.TestCase):
itertools.product(
[(None, None, None, None, None),
(None, None, np.int_, None, None),
(None, (), np.dtype('int'), None, None),
(1, None, None, None, None),
(None, None, None, np.array([1, 2, 3]), (3,)),
(None, None, np.int_, None, (3,))],
global_distribution_strategies),
......@@ -1507,7 +1513,7 @@ class Test_contractions(unittest.TestCase):
@parameterized.expand(
itertools.product([np.dtype('int'), np.dtype('float'),
np.dtype('complex')],
[(0,), (4, 4)],
[(), (0,), (4, 4)],
all_distribution_strategies),
testcase_func_name=custom_name_func)
def test_vdot(self, dtype, global_shape, distribution_strategy):
......@@ -1669,7 +1675,7 @@ class Test_special_methods(unittest.TestCase):
###############################################################################
@parameterized.expand(
itertools.product([(4,), (8, 8), (0, 4), (4, 0, 8)],
itertools.product([(), (4,), (8, 8), (0, 4), (4, 0, 8)],
all_distribution_strategies),
testcase_func_name=custom_name_func)
def test_flatten(self, global_shape, distribution_strategy):
......@@ -1678,7 +1684,7 @@ class Test_special_methods(unittest.TestCase):
distribution_strategy)
assert_equal(obj.flatten().get_full_data(), a.flatten())
p = obj.flatten(inplace=True)
if np.prod(global_shape) != 0:
if np.prod(global_shape) != 0 and global_shape != ():
p[0] = 2222
assert_equal(obj[(0,) * len(global_shape)], 2222)
......@@ -1729,7 +1735,7 @@ class Test_special_methods(unittest.TestCase):
###############################################################################
if FOUND['h5py'] == True:
if FOUND['h5py']:
class Test_load_save(unittest.TestCase):
@parameterized.expand(
......@@ -1747,7 +1753,7 @@ if FOUND['h5py'] == True:
path = os.path.join(tempfile.gettempdir(),
'temp_hdf5_file.hdf5')
if size > 1 and FOUND['h5py_parallel'] == False:
if size > 1 and not FOUND['h5py_parallel']:
assert_raises(RuntimeError,
lambda: obj.save(alias=alias, path=path))
else:
......@@ -1812,7 +1818,7 @@ class Test_axis(unittest.TestCase):
'min', 'amin', 'nanmin', 'argmin', 'max', 'amax',
'nanmax', 'argmax'],
all_datatypes[1:],
[(1, ), (2, 3)],
[(), (1,), (2, 3)],
all_distribution_strategies,
[None, 0, (1, ), (0, 1)]),
testcase_func_name=custom_name_func)
......@@ -1826,15 +1832,20 @@ class Test_axis(unittest.TestCase):
assert_raises(NotImplementedError, lambda: getattr(obj, function)
(axis=axis))
else:
if global_shape != (1,):
if global_shape == (2, 3):
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
else:
elif global_shape == (1,):
if axis in [None, 0, (0,)]:
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
else:
if axis in [None]:
assert_almost_equal(getattr(obj, function)(axis=axis),
getattr(np, function)(a, axis=axis),
decimal=4)
@parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
......@@ -1891,7 +1902,7 @@ class Test_axis(unittest.TestCase):
itertools.product([('argmin_nonflat', 'argmin'),
('argmax_nonflat', 'argmax')],
all_datatypes[1:],
[(0,), (1,), (4, 4, 3), (4, 0, 3)],
[(), (0,), (1,), (4, 4, 3), (4, 0, 3)],
all_distribution_strategies,
[None, (1, ), (1, 2)]),
testcase_func_name=custom_name_func)
......@@ -1900,25 +1911,24 @@ class Test_axis(unittest.TestCase):
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy,
strictly_positive=True)
print (a, obj)
if 0 in global_shape:
assert_raises(ValueError,
lambda: getattr(obj, function_pair[0])(axis=axis))
else:
if axis is not None:
if axis is not None and global_shape != ():
assert_raises(NotImplementedError,
lambda: getattr(obj,
function_pair[0])(axis=axis))
else:
if global_shape != (0,) and global_shape != (1,):
if len(global_shape) > 1:
assert_almost_equal(
getattr(obj, function_pair[0])(axis=axis),
np.unravel_index(getattr(np, function_pair[1])
(a, axis=axis),
dims=global_shape),
decimal=4)
else:
elif len(global_shape) == 1:
assert_almost_equal(getattr(obj, function_pair[0])
(axis=axis),
np.unravel_index(
......@@ -1926,6 +1936,9 @@ class Test_axis(unittest.TestCase):
(a, axis=axis),
dims=global_shape),
decimal=4)
else:
assert_almost_equal(getattr(obj, function_pair[0])
(axis=axis), (0,))
class Test_arange(unittest.TestCase):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment