diff --git a/d2o/__init__.py b/d2o/__init__.py index 73c68ce88b387b0fa7986cead211a3e03b5d4753..3803999556e234b97ba3c57780d1c8d3a5cc6bbb 100644 --- a/d2o/__init__.py +++ b/d2o/__init__.py @@ -24,4 +24,6 @@ from config import configuration from distributed_data_object import distributed_data_object from d2o_librarian import d2o_librarian -from strategies import STRATEGIES \ No newline at end of file +from strategies import STRATEGIES + +from factory_methods import * diff --git a/d2o/distributor_factory.py b/d2o/distributor_factory.py index 68822e19af96f54268277dbce5cae584646d156d..38e4ba423eb7274d0bc75cd863c2763f8f6143f0 100644 --- a/d2o/distributor_factory.py +++ b/d2o/distributor_factory.py @@ -2003,11 +2003,20 @@ class _slicing_distributor(distributor): else: return 'not' + def get_local_arange(self, global_start, global_step): + local_offset = self.local_start * global_step + local_start = global_start + local_offset + local_stop = local_start + self.local_length * global_step + return np.arange(local_start, local_stop, global_step, + dtype=self.dtype) + def _equal_slicer(comm, global_shape): rank = comm.rank size = comm.size + global_shape = tuple(int(x) for x in global_shape) + global_length = global_shape[0] # compute the smallest number of rows the node will get local_length = global_length // size @@ -2028,6 +2037,9 @@ def _equal_slicer(comm, global_shape): def _freeform_slicer(comm, local_shape): rank = comm.rank size = comm.size + + local_shape = tuple(int(x) for x in local_shape) + # Check that all but the first dimensions of local_shape are the same local_sub_shape = local_shape[1:] local_sub_shape_list = comm.allgather(local_sub_shape) @@ -2052,6 +2064,8 @@ def _freeform_slicer(comm, local_shape): if 'pyfftw' in gdi: def _fftw_slicer(comm, global_shape): + global_shape = tuple(int(x) for x in global_shape) + if gc['mpi_module'] != 'MPI': comm = None # pyfftw.local_size crashes if any of the entries of global_shape @@ -2076,7 +2090,7 @@ class _not_distributor(distributor): def __init__(self, global_shape, dtype, comm, *args, **kwargs): self.comm = comm self.dtype = dtype - self.global_shape = global_shape + self.global_shape = tuple(int(x) for x in global_shape) self.local_shape = self.global_shape self.distribution_strategy = 'not' @@ -2340,3 +2354,8 @@ class _not_distributor(distributor): def get_axes_local_distribution_strategy(self, axes): return 'not' + + def get_local_arange(self, global_start, global_step): + global_stop = global_start + self.global_shape[0]*global_step + return np.arange(global_start, global_stop, global_step, + dtype=self.dtype) diff --git a/d2o/factory_methods.py b/d2o/factory_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..294822e169263f758189bf3176f55df9290ce47b --- /dev/null +++ b/d2o/factory_methods.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- + +import numpy as np + +from d2o.config import configuration as gc + +from distributed_data_object import distributed_data_object + +from strategies import STRATEGIES + +__all__ = ['arange'] + + +def arange(start, stop=None, step=None, dtype=np.int, + distribution_strategy=gc['default_distribution_strategy']): + + # Check if the distribution_strategy is a global type one + if distribution_strategy not in STRATEGIES['global']: + raise ValueError("ERROR: distribution_strategy must be a global one.") + + # parse the start/stop/step/dtype input + if step is None: + step = 1 + else: + step = int(step) + if step < 1: + raise ValueError("ERROR: positive step size needed.") + + dtype = np.dtype(dtype) + + if stop is not None: + try: + stop = int(stop) + except(TypeError): + raise ValueError("ERROR: no valid 'stop' found.") + try: + start = int(start) + except(TypeError): + raise ValueError("ERROR: no valid 'start' found.") + else: + try: + stop = int(start) + except(TypeError): + raise ValueError("ERROR: no valid 'start' found.") + start = 0 + + # create the empty distributed_data_object + global_shape = (np.ceil(1.*(stop-start)/step), ) + obj = distributed_data_object(global_shape=global_shape, + dtype=dtype, + distribution_strategy=distribution_strategy) + + # fill obj with the local range-data + local_arange = obj.distributor.get_local_arange(global_start=start, + global_step=step) + obj.set_local_data(local_arange, copy=False) + + return obj diff --git a/d2o/version.py b/d2o/version.py index 6425e312a4c1ae8cc739c62e67ee5bc0a1ed6aaf..0abc626a8220caffdb47c86a6cc571291507b5cd 100644 --- a/d2o/version.py +++ b/d2o/version.py @@ -20,4 +20,4 @@ # 1) we don't load dependencies by storing it in __init__.py # 2) we can import it in setup.py for the same reason # 3) we can import it into your module module -__version__ = '1.0.2' \ No newline at end of file +__version__ = '1.0.3' diff --git a/test/test_distributed_data_object.py b/test/test_distributed_data_object.py index 54056137ddce7d2bd9106b6f0102df2bd3c272ca..4ba08f4f72f175b388278d8b12127e2d0d6432d6 100644 --- a/test/test_distributed_data_object.py +++ b/test/test_distributed_data_object.py @@ -30,7 +30,8 @@ import warnings import tempfile from d2o import distributed_data_object,\ - STRATEGIES + STRATEGIES,\ + arange from distutils.version import LooseVersion as lv @@ -1906,7 +1907,8 @@ class Test_axis(unittest.TestCase): else: if axis is not None: assert_raises(NotImplementedError, - lambda: getattr(obj, function_pair[0])(axis=axis)) + lambda: getattr(obj, + function_pair[0])(axis=axis)) else: if global_shape != (0,) and global_shape != (1,): @@ -1924,3 +1926,20 @@ class Test_axis(unittest.TestCase): (a, axis=axis), dims=global_shape), decimal=4) + + +class Test_arange(unittest.TestCase): + @parameterized.expand( + itertools.product(all_datatypes[1:], + [(11, None, None), + (1, 23, None), + (2, 20, 2), + (2, 21, 2)], + global_distribution_strategies), + testcase_func_name=custom_name_func) + def test_arange(self, dtype, sss, distribution_strategy): + obj = arange(start=sss[0], stop=sss[1], step=sss[2], dtype=dtype, + distribution_strategy=distribution_strategy) + a = np.arange(start=sss[0], stop=sss[1], step=sss[2], dtype=dtype) + assert_equal(obj.get_full_data(), a) + assert_equal(obj.dtype, a.dtype)