Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
D2O
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
22
Issues
22
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Operations
Operations
Incidents
Environments
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ift
D2O
Commits
28cf619a
Commit
28cf619a
authored
Sep 12, 2016
by
theos
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added d2o.arange in factory_methods.py
parent
005fbbb8
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
103 additions
and
5 deletions
+103
-5
d2o/__init__.py
d2o/__init__.py
+3
-1
d2o/distributor_factory.py
d2o/distributor_factory.py
+20
-1
d2o/factory_methods.py
d2o/factory_methods.py
+58
-0
d2o/version.py
d2o/version.py
+1
-1
test/test_distributed_data_object.py
test/test_distributed_data_object.py
+21
-2
No files found.
d2o/__init__.py
View file @
28cf619a
...
...
@@ -25,3 +25,5 @@ from distributed_data_object import distributed_data_object
from
d2o_librarian
import
d2o_librarian
from
strategies
import
STRATEGIES
from
factory_methods
import
*
d2o/distributor_factory.py
View file @
28cf619a
...
...
@@ -2003,11 +2003,20 @@ class _slicing_distributor(distributor):
else
:
return
'not'
def
get_local_arange
(
self
,
global_start
,
global_step
):
local_offset
=
self
.
local_start
*
global_step
local_start
=
global_start
+
local_offset
local_stop
=
local_start
+
self
.
local_length
*
global_step
return
np
.
arange
(
local_start
,
local_stop
,
global_step
,
dtype
=
self
.
dtype
)
def
_equal_slicer
(
comm
,
global_shape
):
rank
=
comm
.
rank
size
=
comm
.
size
global_shape
=
tuple
(
int
(
x
)
for
x
in
global_shape
)
global_length
=
global_shape
[
0
]
# compute the smallest number of rows the node will get
local_length
=
global_length
//
size
...
...
@@ -2028,6 +2037,9 @@ def _equal_slicer(comm, global_shape):
def
_freeform_slicer
(
comm
,
local_shape
):
rank
=
comm
.
rank
size
=
comm
.
size
local_shape
=
tuple
(
int
(
x
)
for
x
in
local_shape
)
# Check that all but the first dimensions of local_shape are the same
local_sub_shape
=
local_shape
[
1
:]
local_sub_shape_list
=
comm
.
allgather
(
local_sub_shape
)
...
...
@@ -2052,6 +2064,8 @@ def _freeform_slicer(comm, local_shape):
if
'pyfftw'
in
gdi
:
def
_fftw_slicer
(
comm
,
global_shape
):
global_shape
=
tuple
(
int
(
x
)
for
x
in
global_shape
)
if
gc
[
'mpi_module'
]
!=
'MPI'
:
comm
=
None
# pyfftw.local_size crashes if any of the entries of global_shape
...
...
@@ -2076,7 +2090,7 @@ class _not_distributor(distributor):
def
__init__
(
self
,
global_shape
,
dtype
,
comm
,
*
args
,
**
kwargs
):
self
.
comm
=
comm
self
.
dtype
=
dtype
self
.
global_shape
=
global_shape
self
.
global_shape
=
tuple
(
int
(
x
)
for
x
in
global_shape
)
self
.
local_shape
=
self
.
global_shape
self
.
distribution_strategy
=
'not'
...
...
@@ -2340,3 +2354,8 @@ class _not_distributor(distributor):
def
get_axes_local_distribution_strategy
(
self
,
axes
):
return
'not'
def
get_local_arange
(
self
,
global_start
,
global_step
):
global_stop
=
global_start
+
self
.
global_shape
[
0
]
*
global_step
return
np
.
arange
(
global_start
,
global_stop
,
global_step
,
dtype
=
self
.
dtype
)
d2o/factory_methods.py
0 → 100644
View file @
28cf619a
# -*- coding: utf-8 -*-
import
numpy
as
np
from
d2o.config
import
configuration
as
gc
from
distributed_data_object
import
distributed_data_object
from
strategies
import
STRATEGIES
__all__
=
[
'arange'
]
def
arange
(
start
,
stop
=
None
,
step
=
None
,
dtype
=
np
.
int
,
distribution_strategy
=
gc
[
'default_distribution_strategy'
]):
# Check if the distribution_strategy is a global type one
if
distribution_strategy
not
in
STRATEGIES
[
'global'
]:
raise
ValueError
(
"ERROR: distribution_strategy must be a global one."
)
# parse the start/stop/step/dtype input
if
step
is
None
:
step
=
1
else
:
step
=
int
(
step
)
if
step
<
1
:
raise
ValueError
(
"ERROR: positive step size needed."
)
dtype
=
np
.
dtype
(
dtype
)
if
stop
is
not
None
:
try
:
stop
=
int
(
stop
)
except
(
TypeError
):
raise
ValueError
(
"ERROR: no valid 'stop' found."
)
try
:
start
=
int
(
start
)
except
(
TypeError
):
raise
ValueError
(
"ERROR: no valid 'start' found."
)
else
:
try
:
stop
=
int
(
start
)
except
(
TypeError
):
raise
ValueError
(
"ERROR: no valid 'start' found."
)
start
=
0
# create the empty distributed_data_object
global_shape
=
(
np
.
ceil
(
1.
*
(
stop
-
start
)
/
step
),
)
obj
=
distributed_data_object
(
global_shape
=
global_shape
,
dtype
=
dtype
,
distribution_strategy
=
distribution_strategy
)
# fill obj with the local range-data
local_arange
=
obj
.
distributor
.
get_local_arange
(
global_start
=
start
,
global_step
=
step
)
obj
.
set_local_data
(
local_arange
,
copy
=
False
)
return
obj
d2o/version.py
View file @
28cf619a
...
...
@@ -20,4 +20,4 @@
# 1) we don't load dependencies by storing it in __init__.py
# 2) we can import it in setup.py for the same reason
# 3) we can import it into your module module
__version__
=
'1.0.2'
\ No newline at end of file
__version__
=
'1.0.3'
test/test_distributed_data_object.py
View file @
28cf619a
...
...
@@ -30,7 +30,8 @@ import warnings
import
tempfile
from
d2o
import
distributed_data_object
,
\
STRATEGIES
STRATEGIES
,
\
arange
from
distutils.version
import
LooseVersion
as
lv
...
...
@@ -1906,7 +1907,8 @@ class Test_axis(unittest.TestCase):
else
:
if
axis
is
not
None
:
assert_raises
(
NotImplementedError
,
lambda
:
getattr
(
obj
,
function_pair
[
0
])(
axis
=
axis
))
lambda
:
getattr
(
obj
,
function_pair
[
0
])(
axis
=
axis
))
else
:
if
global_shape
!=
(
0
,)
and
global_shape
!=
(
1
,):
...
...
@@ -1924,3 +1926,20 @@ class Test_axis(unittest.TestCase):
(
a
,
axis
=
axis
),
dims
=
global_shape
),
decimal
=
4
)
class
Test_arange
(
unittest
.
TestCase
):
@
parameterized
.
expand
(
itertools
.
product
(
all_datatypes
[
1
:],
[(
11
,
None
,
None
),
(
1
,
23
,
None
),
(
2
,
20
,
2
),
(
2
,
21
,
2
)],
global_distribution_strategies
),
testcase_func_name
=
custom_name_func
)
def
test_arange
(
self
,
dtype
,
sss
,
distribution_strategy
):
obj
=
arange
(
start
=
sss
[
0
],
stop
=
sss
[
1
],
step
=
sss
[
2
],
dtype
=
dtype
,
distribution_strategy
=
distribution_strategy
)
a
=
np
.
arange
(
start
=
sss
[
0
],
stop
=
sss
[
1
],
step
=
sss
[
2
],
dtype
=
dtype
)
assert_equal
(
obj
.
get_full_data
(),
a
)
assert_equal
(
obj
.
dtype
,
a
.
dtype
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment