Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
What's new
7
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Open sidebar
ift
D2O
Commits
65098d3f
Commit
65098d3f
authored
Jul 04, 2016
by
theos
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added a simple search-sorted functionality.
Fixed 'mean' method.
parent
50113262
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
55 additions
and
5 deletions
+55
-5
d2o/cast_axis_to_tuple.py
d2o/cast_axis_to_tuple.py
+5
-2
d2o/distributed_data_object.py
d2o/distributed_data_object.py
+25
-1
d2o/distributor_factory.py
d2o/distributor_factory.py
+24
-1
d2o/version.py
d2o/version.py
+1
-1
No files found.
d2o/cast_axis_to_tuple.py
View file @
65098d3f
...
...
@@ -19,15 +19,18 @@
import
numpy
as
np
def
cast_axis_to_tuple
(
axis
):
def
cast_axis_to_tuple
(
axis
,
length
):
if
axis
is
None
:
return
None
try
:
axis
=
tuple
(
[
int
(
item
)
for
item
in
axis
]
)
axis
=
tuple
(
int
(
item
)
for
item
in
axis
)
except
(
TypeError
):
if
np
.
isscalar
(
axis
):
axis
=
(
int
(
axis
),
)
else
:
raise
TypeError
(
"ERROR: Could not convert axis-input to tuple of ints"
)
# shift negative indices to positive ones
axis
=
tuple
(
item
if
(
item
>=
0
)
else
(
item
+
length
)
for
item
in
axis
)
return
axis
d2o/distributed_data_object.py
View file @
65098d3f
...
...
@@ -1172,7 +1172,7 @@ class distributed_data_object(object):
def
mean
(
self
,
axis
=
None
,
**
kwargs
):
# infer, which axes will be collapsed
axis
=
cast_axis_to_tuple
(
axis
)
axis
=
cast_axis_to_tuple
(
axis
,
length
=
len
(
self
.
shape
)
)
if
axis
is
None
:
collapsed_shapes
=
self
.
shape
else
:
...
...
@@ -1821,6 +1821,30 @@ class distributed_data_object(object):
return
self
.
distributor
.
cumsum
(
parent
=
self
,
axis
=
axis
)
def
searchsorted
(
self
,
v
,
side
=
'left'
):
""" Find indices where elements should be inserted to maintain order.
Find the indices into a sorted array `a` such that, if the
corresponding elements in `v` were inserted before the indices, the
order of `a` would be preserved.
Parameters
----------
a : 1-D array_like
Input array. If `sorter` is None, then it must be sorted in
ascending order, otherwise `sorter` must be an array of indices
that sort it.
v : array_like
Values to insert into `a`.
side : {'left', 'right'}, optional
If 'left', the index of the first suitable location found is given.
If 'right', return the last such index. If there is no suitable
index, return either 0 or N (where N is the length of `a`).
"""
return
self
.
distributor
.
searchsorted
(
obj
=
self
,
v
=
v
,
side
=
side
)
def
save
(
self
,
alias
,
path
=
None
,
overwriteQ
=
True
):
""" Saves the distributed_data_object to disk utilizing h5py.
...
...
d2o/distributor_factory.py
View file @
65098d3f
...
...
@@ -588,7 +588,7 @@ class _slicing_distributor(distributor):
return
parent
.
copy
()
old_shape
=
parent
.
shape
axis
=
cast_axis_to_tuple
(
axis
)
axis
=
cast_axis_to_tuple
(
axis
,
length
=
len
(
self
.
global_shape
)
)
if
axis
is
None
:
new_shape
=
()
else
:
...
...
@@ -1551,6 +1551,25 @@ class _slicing_distributor(distributor):
return
result_d2o
def
searchsorted
(
self
,
obj
,
v
,
side
=
'left'
):
a
=
obj
.
get_local_data
(
copy
=
False
)
local_searched
=
np
.
searchsorted
(
a
=
a
,
v
=
v
,
side
=
side
)
global_searched
=
np
.
empty_like
(
local_searched
)
if
side
is
'left'
:
op
=
MPI
.
MAX
elif
side
is
'right'
:
op
=
MPI
.
MIN
else
:
raise
ValueError
self
.
comm
.
Allreduce
([
local_searched
,
MPI
.
INT
],
[
global_searched
,
MPI
.
INT
],
op
=
op
)
if
global_searched
.
shape
==
():
global_searched
=
global_searched
[()]
return
global_searched
def
_sliceify
(
self
,
inp
):
sliceified
=
[]
result
=
[]
...
...
@@ -2042,6 +2061,10 @@ class _not_distributor(distributor):
result_d2o
.
set_local_data
(
local_cumsum
,
copy
=
False
)
return
result_d2o
def
searchsorted
(
self
,
obj
,
v
,
side
=
'left'
):
a
=
obj
.
get_local_data
(
copy
=
False
)
return
np
.
searchsorted
(
a
=
a
,
v
=
v
,
side
=
side
)
if
'h5py'
in
gdi
:
def
save_data
(
self
,
data
,
alias
,
path
=
None
,
overwriteQ
=
True
):
comm
=
self
.
comm
...
...
d2o/version.py
View file @
65098d3f
...
...
@@ -20,4 +20,4 @@
# 1) we don't load dependencies by storing it in __init__.py
# 2) we can import it in setup.py for the same reason
# 3) we can import it into your module module
__version__
=
'1.0.0'
\ No newline at end of file
__version__
=
'1.0.1'
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment