Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
c448ab64
Commit
c448ab64
authored
Apr 28, 2016
by
theos
Browse files
Fixed the distribution_strategy handling of cumsum.
parent
d66fd65f
Changes
2
Hide whitespace changes
Inline
Side-by-side
d2o/distributed_data_object.py
View file @
c448ab64
...
...
@@ -1798,19 +1798,7 @@ class distributed_data_object(object):
Contains the results of the cummulative sum.
"""
cumsum_data
=
self
.
distributor
.
cumsum
(
self
.
data
,
axis
=
axis
)
if
axis
is
None
:
flat_global_shape
=
(
np
.
prod
(
self
.
shape
),)
flat_local_shape
=
np
.
shape
(
cumsum_data
)
result_d2o
=
self
.
copy_empty
(
global_shape
=
flat_global_shape
,
local_shape
=
flat_local_shape
)
else
:
result_d2o
=
self
.
copy_empty
()
result_d2o
.
set_local_data
(
cumsum_data
)
return
result_d2o
return
self
.
distributor
.
cumsum
(
parent
=
self
,
axis
=
axis
)
def
save
(
self
,
alias
,
path
=
None
,
overwriteQ
=
True
):
""" Saves the distributed_data_object to disk utilizing h5py.
...
...
d2o/distributor_factory.py
View file @
c448ab64
...
...
@@ -1506,7 +1506,8 @@ class _slicing_distributor(distributor):
MPI
.
SUM
)
return
global_counts
def
cumsum
(
self
,
data
,
axis
):
def
cumsum
(
self
,
parent
,
axis
):
data
=
parent
.
data
# compute the local np.cumsum
local_cumsum
=
np
.
cumsum
(
data
,
axis
=
axis
)
if
axis
is
None
or
axis
==
0
:
...
...
@@ -1520,7 +1521,28 @@ class _slicing_distributor(distributor):
local_sum_of_shift
=
np
.
sum
(
local_shift_list
[:
rank
],
axis
=
0
)
local_cumsum
+=
local_sum_of_shift
return
local_cumsum
# create the return d2o
if
axis
is
None
:
# try to preserve the distribution_strategy
flat_global_shape
=
(
self
.
global_dim
,
)
flat_local_shape
=
np
.
shape
(
local_cumsum
)
result_d2o
=
parent
.
copy_empty
(
global_shape
=
flat_global_shape
,
local_shape
=
flat_local_shape
)
# check if the original distribution strategy yielded a suitable
# local_shape
if
result_d2o
.
local_shape
!=
flat_local_shape
:
# if it does not fit, construct a freeform d2o
result_d2o
=
parent
.
copy_empty
(
global_shape
=
flat_global_shape
,
local_shape
=
flat_local_shape
,
distribution_strategy
=
'freeform'
)
else
:
result_d2o
=
parent
.
copy_empty
()
result_d2o
.
set_local_data
(
local_cumsum
,
copy
=
False
)
return
result_d2o
def
_sliceify
(
self
,
inp
):
sliceified
=
[]
...
...
@@ -2001,10 +2023,13 @@ class _not_distributor(distributor):
minlength
=
minlength
)
return
counts
def
cumsum
(
self
,
data
,
axis
):
def
cumsum
(
self
,
parent
,
axis
):
data
=
parent
.
data
# compute the local results from np.cumsum
cumsum
=
np
.
cumsum
(
data
,
axis
=
axis
)
return
cumsum
local_cumsum
=
np
.
cumsum
(
data
,
axis
=
axis
)
result_d2o
=
parent
.
copy_empty
(
global_shape
=
local_cumsum
.
shape
)
result_d2o
.
set_local_data
(
local_cumsum
,
copy
=
False
)
return
result_d2o
if
'h5py'
in
gdi
:
def
save_data
(
self
,
data
,
alias
,
path
=
None
,
overwriteQ
=
True
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment