Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
D2O
Commits
65098d3f
Commit
65098d3f
authored
Jul 04, 2016
by
theos
Browse files
Added a simple search-sorted functionality.
Fixed 'mean' method.
parent
50113262
Changes
4
Hide whitespace changes
Inline
Side-by-side
d2o/cast_axis_to_tuple.py
View file @
65098d3f
...
...
@@ -19,15 +19,18 @@
import
numpy
as
np
def
cast_axis_to_tuple
(
axis
):
def
cast_axis_to_tuple
(
axis
,
length
):
if
axis
is
None
:
return
None
try
:
axis
=
tuple
(
[
int
(
item
)
for
item
in
axis
]
)
axis
=
tuple
(
int
(
item
)
for
item
in
axis
)
except
(
TypeError
):
if
np
.
isscalar
(
axis
):
axis
=
(
int
(
axis
),
)
else
:
raise
TypeError
(
"ERROR: Could not convert axis-input to tuple of ints"
)
# shift negative indices to positive ones
axis
=
tuple
(
item
if
(
item
>=
0
)
else
(
item
+
length
)
for
item
in
axis
)
return
axis
d2o/distributed_data_object.py
View file @
65098d3f
...
...
@@ -1172,7 +1172,7 @@ class distributed_data_object(object):
def
mean
(
self
,
axis
=
None
,
**
kwargs
):
# infer, which axes will be collapsed
axis
=
cast_axis_to_tuple
(
axis
)
axis
=
cast_axis_to_tuple
(
axis
,
length
=
len
(
self
.
shape
)
)
if
axis
is
None
:
collapsed_shapes
=
self
.
shape
else
:
...
...
@@ -1821,6 +1821,30 @@ class distributed_data_object(object):
return
self
.
distributor
.
cumsum
(
parent
=
self
,
axis
=
axis
)
def
searchsorted
(
self
,
v
,
side
=
'left'
):
""" Find indices where elements should be inserted to maintain order.
Find the indices into a sorted array `a` such that, if the
corresponding elements in `v` were inserted before the indices, the
order of `a` would be preserved.
Parameters
----------
a : 1-D array_like
Input array. If `sorter` is None, then it must be sorted in
ascending order, otherwise `sorter` must be an array of indices
that sort it.
v : array_like
Values to insert into `a`.
side : {'left', 'right'}, optional
If 'left', the index of the first suitable location found is given.
If 'right', return the last such index. If there is no suitable
index, return either 0 or N (where N is the length of `a`).
"""
return
self
.
distributor
.
searchsorted
(
obj
=
self
,
v
=
v
,
side
=
side
)
def
save
(
self
,
alias
,
path
=
None
,
overwriteQ
=
True
):
""" Saves the distributed_data_object to disk utilizing h5py.
...
...
d2o/distributor_factory.py
View file @
65098d3f
...
...
@@ -588,7 +588,7 @@ class _slicing_distributor(distributor):
return
parent
.
copy
()
old_shape
=
parent
.
shape
axis
=
cast_axis_to_tuple
(
axis
)
axis
=
cast_axis_to_tuple
(
axis
,
length
=
len
(
self
.
global_shape
)
)
if
axis
is
None
:
new_shape
=
()
else
:
...
...
@@ -1551,6 +1551,25 @@ class _slicing_distributor(distributor):
return
result_d2o
def
searchsorted
(
self
,
obj
,
v
,
side
=
'left'
):
a
=
obj
.
get_local_data
(
copy
=
False
)
local_searched
=
np
.
searchsorted
(
a
=
a
,
v
=
v
,
side
=
side
)
global_searched
=
np
.
empty_like
(
local_searched
)
if
side
is
'left'
:
op
=
MPI
.
MAX
elif
side
is
'right'
:
op
=
MPI
.
MIN
else
:
raise
ValueError
self
.
comm
.
Allreduce
([
local_searched
,
MPI
.
INT
],
[
global_searched
,
MPI
.
INT
],
op
=
op
)
if
global_searched
.
shape
==
():
global_searched
=
global_searched
[()]
return
global_searched
def
_sliceify
(
self
,
inp
):
sliceified
=
[]
result
=
[]
...
...
@@ -2042,6 +2061,10 @@ class _not_distributor(distributor):
result_d2o
.
set_local_data
(
local_cumsum
,
copy
=
False
)
return
result_d2o
def
searchsorted
(
self
,
obj
,
v
,
side
=
'left'
):
a
=
obj
.
get_local_data
(
copy
=
False
)
return
np
.
searchsorted
(
a
=
a
,
v
=
v
,
side
=
side
)
if
'h5py'
in
gdi
:
def
save_data
(
self
,
data
,
alias
,
path
=
None
,
overwriteQ
=
True
):
comm
=
self
.
comm
...
...
d2o/version.py
View file @
65098d3f
...
...
@@ -20,4 +20,4 @@
# 1) we don't load dependencies by storing it in __init__.py
# 2) we can import it in setup.py for the same reason
# 3) we can import it into your module module
__version__
=
'1.0.0'
\ No newline at end of file
__version__
=
'1.0.1'
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment