Commit 28cf619a authored by theos's avatar theos

Added d2o.arange in factory_methods.py

parent 005fbbb8
......@@ -25,3 +25,5 @@ from distributed_data_object import distributed_data_object
from d2o_librarian import d2o_librarian
from strategies import STRATEGIES
from factory_methods import *
......@@ -2003,11 +2003,20 @@ class _slicing_distributor(distributor):
else:
return 'not'
def get_local_arange(self, global_start, global_step):
local_offset = self.local_start * global_step
local_start = global_start + local_offset
local_stop = local_start + self.local_length * global_step
return np.arange(local_start, local_stop, global_step,
dtype=self.dtype)
def _equal_slicer(comm, global_shape):
rank = comm.rank
size = comm.size
global_shape = tuple(int(x) for x in global_shape)
global_length = global_shape[0]
# compute the smallest number of rows the node will get
local_length = global_length // size
......@@ -2028,6 +2037,9 @@ def _equal_slicer(comm, global_shape):
def _freeform_slicer(comm, local_shape):
rank = comm.rank
size = comm.size
local_shape = tuple(int(x) for x in local_shape)
# Check that all but the first dimensions of local_shape are the same
local_sub_shape = local_shape[1:]
local_sub_shape_list = comm.allgather(local_sub_shape)
......@@ -2052,6 +2064,8 @@ def _freeform_slicer(comm, local_shape):
if 'pyfftw' in gdi:
def _fftw_slicer(comm, global_shape):
global_shape = tuple(int(x) for x in global_shape)
if gc['mpi_module'] != 'MPI':
comm = None
# pyfftw.local_size crashes if any of the entries of global_shape
......@@ -2076,7 +2090,7 @@ class _not_distributor(distributor):
def __init__(self, global_shape, dtype, comm, *args, **kwargs):
self.comm = comm
self.dtype = dtype
self.global_shape = global_shape
self.global_shape = tuple(int(x) for x in global_shape)
self.local_shape = self.global_shape
self.distribution_strategy = 'not'
......@@ -2340,3 +2354,8 @@ class _not_distributor(distributor):
def get_axes_local_distribution_strategy(self, axes):
return 'not'
def get_local_arange(self, global_start, global_step):
global_stop = global_start + self.global_shape[0]*global_step
return np.arange(global_start, global_stop, global_step,
dtype=self.dtype)
# -*- coding: utf-8 -*-
import numpy as np
from d2o.config import configuration as gc
from distributed_data_object import distributed_data_object
from strategies import STRATEGIES
__all__ = ['arange']
def arange(start, stop=None, step=None, dtype=np.int,
distribution_strategy=gc['default_distribution_strategy']):
# Check if the distribution_strategy is a global type one
if distribution_strategy not in STRATEGIES['global']:
raise ValueError("ERROR: distribution_strategy must be a global one.")
# parse the start/stop/step/dtype input
if step is None:
step = 1
else:
step = int(step)
if step < 1:
raise ValueError("ERROR: positive step size needed.")
dtype = np.dtype(dtype)
if stop is not None:
try:
stop = int(stop)
except(TypeError):
raise ValueError("ERROR: no valid 'stop' found.")
try:
start = int(start)
except(TypeError):
raise ValueError("ERROR: no valid 'start' found.")
else:
try:
stop = int(start)
except(TypeError):
raise ValueError("ERROR: no valid 'start' found.")
start = 0
# create the empty distributed_data_object
global_shape = (np.ceil(1.*(stop-start)/step), )
obj = distributed_data_object(global_shape=global_shape,
dtype=dtype,
distribution_strategy=distribution_strategy)
# fill obj with the local range-data
local_arange = obj.distributor.get_local_arange(global_start=start,
global_step=step)
obj.set_local_data(local_arange, copy=False)
return obj
......@@ -20,4 +20,4 @@
# 1) we don't load dependencies by storing it in __init__.py
# 2) we can import it in setup.py for the same reason
# 3) we can import it into your module module
__version__ = '1.0.2'
\ No newline at end of file
__version__ = '1.0.3'
......@@ -30,7 +30,8 @@ import warnings
import tempfile
from d2o import distributed_data_object,\
STRATEGIES
STRATEGIES,\
arange
from distutils.version import LooseVersion as lv
......@@ -1906,7 +1907,8 @@ class Test_axis(unittest.TestCase):
else:
if axis is not None:
assert_raises(NotImplementedError,
lambda: getattr(obj, function_pair[0])(axis=axis))
lambda: getattr(obj,
function_pair[0])(axis=axis))
else:
if global_shape != (0,) and global_shape != (1,):
......@@ -1924,3 +1926,20 @@ class Test_axis(unittest.TestCase):
(a, axis=axis),
dims=global_shape),
decimal=4)
class Test_arange(unittest.TestCase):
@parameterized.expand(
itertools.product(all_datatypes[1:],
[(11, None, None),
(1, 23, None),
(2, 20, 2),
(2, 21, 2)],
global_distribution_strategies),
testcase_func_name=custom_name_func)
def test_arange(self, dtype, sss, distribution_strategy):
obj = arange(start=sss[0], stop=sss[1], step=sss[2], dtype=dtype,
distribution_strategy=distribution_strategy)
a = np.arange(start=sss[0], stop=sss[1], step=sss[2], dtype=dtype)
assert_equal(obj.get_full_data(), a)
assert_equal(obj.dtype, a.dtype)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment