ift / D2O · Commits

Commit fd33af41, authored Jul 09, 2016 by theos
Improved set_full_data.
Parent: 2621861d

Changes: 1 file
d2o/distributor_factory.py
...
...
@@ -300,6 +300,41 @@ def _infer_key_type(key):
class distributor(object):

    def distribute_data(self, data=None, alias=None, path=None, copy=True,
                        **kwargs):
        '''
        distribute data checks
        - whether the data is located on all nodes or only on node 0
        - that the shape of 'data' matches the global_shape
        '''
        if 'h5py' in gdi and alias is not None:
            data = self.load_data(alias=alias, path=path)

        if data is None:
            return np.empty(self.local_shape, dtype=self.dtype)
        elif np.isscalar(data):
            return np.ones(self.local_shape, dtype=self.dtype) * data
        elif isinstance(data, np.ndarray) or \
                isinstance(data, distributed_data_object):
            data = self.extract_local_data(data)
            if data.shape is not self.local_shape:
                copy = True

            if copy:
                result_data = np.empty(self.local_shape, dtype=self.dtype)
                result_data[:] = data
            else:
                result_data = data

            return result_data
        else:
            new_data = np.array(data)
            return new_data.astype(self.dtype,
                                   copy=copy).reshape(self.local_shape)

    def disperse_data(self, data, to_key, data_update, from_key=None,
                      local_keys=False, copy=True, **kwargs):
        # Check which keys we got:
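Taken together, the branches above normalize whatever the caller hands in (None, a scalar, an ndarray or distributed_data_object, or a generic sequence) into a node-local NumPy array with the distributor's local_shape and dtype. The following is a standalone sketch of that normalization using plain NumPy only; the helper name and the explicit local_shape/dtype arguments are illustrative stand-ins for the distributor attributes, not d2o API.

import numpy as np

def normalize_to_local(data, local_shape, dtype, copy=True):
    # None -> uninitialized local buffer
    if data is None:
        return np.empty(local_shape, dtype=dtype)
    # scalar -> constant local array
    elif np.isscalar(data):
        return np.ones(local_shape, dtype=dtype) * data
    # ndarray of matching shape -> optionally copy into a fresh buffer
    elif isinstance(data, np.ndarray) and data.shape == tuple(local_shape):
        if copy:
            result = np.empty(local_shape, dtype=dtype)
            result[:] = data
            return result
        return data
    # anything else -> coerce, cast and reshape
    else:
        return np.asarray(data).astype(dtype, copy=copy).reshape(local_shape)

# Example: a plain Python list is coerced and reshaped to the local slab.
print(normalize_to_local([1, 2, 3, 4, 5, 6], (2, 3), np.float64))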
...
...
@@ -456,18 +491,24 @@ class distributor(object):
        # bincounts
        for slice_list in slicing_generator(flat_shape,
                                            axes=(len(flat_shape) - 1, )):
-           local_counts[slice_list] = np.bincount(data[slice_list],
-                                                  weights=local_weights,
-                                                  minlength=length)
+           if local_weights is not None:
+               current_weights = local_weights[slice_list]
+           else:
+               current_weights = None
+           local_counts[slice_list] = np.bincount(data[slice_list],
+                                                  weights=current_weights,
+                                                  minlength=length)

        # restore the original ordering
        # place the bincount stuff at the location of the first `axis` entry
        if axis is not None:
            # axis has been sorted above
            insert_position = axis[0]
+           new_ndim = len(local_counts.shape)
            return_order = (range(0, insert_position) +
-                           [ndim - 1, ] +
-                           range(insert_position, ndim - 1))
+                           [new_ndim - 1, ] +
+                           range(insert_position, new_ndim - 1))
            local_counts = np.ascontiguousarray(
                local_counts.transpose(return_order))
        return self._combine_local_bincount_counts(obj, local_counts, axis)
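The change above slices the weights with the same slice_list as the data before calling np.bincount: np.bincount expects weights to have exactly the length of its first argument, so handing it the full multi-dimensional local_weights would generally fail once the data is processed slice by slice. A small self-contained illustration with made-up numbers follows.

import numpy as np

data = np.array([[0, 1, 1],
                 [2, 2, 0]])
weights = np.array([[0.5, 1.0, 1.5],
                    [2.0, 2.5, 3.0]])
length = 3

counts = np.empty((2, length))
for i in range(data.shape[0]):
    counts[i] = np.bincount(data[i],
                            weights=weights[i],   # slice weights like the data
                            minlength=length)

print(counts)
# row 0: class 0 -> 0.5, class 1 -> 1.0 + 1.5, class 2 -> 0.0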
...
...
@@ -714,43 +755,69 @@ class _slicing_distributor(distributor):
        return result

    def distribute_data(self, data=None, alias=None,
                        path=None, copy=True, **kwargs):
        '''
        distribute data checks
        - whether the data is located on all nodes or only on node 0
        - that the shape of 'data' matches the global_shape
        '''
        comm = self.comm

        if 'h5py' in gdi and alias is not None:
            data = self.load_data(alias=alias, path=path)

        local_data_available_Q = (data is not None)
        data_available_Q = np.array(comm.allgather(local_data_available_Q))

        if np.all(data_available_Q == False):
            return np.empty(self.local_shape, dtype=self.dtype, order='C')
        # if all nodes got data, we assume that it is the right data and
        # store it individually.
        elif np.all(data_available_Q == True):
            if isinstance(data, distributed_data_object):
                temp_d2o = data.get_data((slice(self.local_start,
                                                self.local_end),),
                                         local_keys=True,
                                         copy=copy)
                return temp_d2o.get_local_data(copy=False).astype(self.dtype,
                                                                  copy=False)
            elif np.isscalar(data):
                return np.ones(self.local_shape, dtype=self.dtype) * data
            else:
                return data[self.local_start:self.local_end].astype(self.dtype,
                                                                    copy=copy)
        else:
            raise ValueError(
                "ERROR: distribute_data must get data on all nodes!")
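The rewritten method above first lets all MPI processes agree on whether data was supplied: every rank contributes a boolean through comm.allgather, and only the all-True and all-False outcomes are accepted; anything mixed raises. Below is a minimal sketch of that handshake written directly against mpi4py; the toy data, variable names and file name are illustrative, not part of d2o's API. Run it with something like `mpirun -n 4 python availability_check.py`.

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

data = np.arange(12)                      # pretend every rank was handed data
local_data_available = (data is not None)
data_available = np.array(comm.allgather(local_data_available))

if np.all(data_available == False):
    print(rank, "no rank got data -> allocate an empty local array")
elif np.all(data_available == True):
    print(rank, "all ranks got data -> keep only the local slice")
else:
    raise ValueError("distribute_data must get data on all nodes!")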
# def distribute_data(self, data=None, alias=None,
# path=None, copy=True, **kwargs):
# '''
# distribute data checks
# - whether the data is located on all nodes or only on node 0
# - that the shape of 'data' matches the global_shape
# '''
#
## comm = self.comm
#
# if 'h5py' in gdi and alias is not None:
# data = self.load_data(alias=alias, path=path)
#
# if data is None:
# return np.empty(self.global_shape, dtype=self.dtype)
# elif np.isscalar(data):
# return np.ones(self.global_shape, dtype=self.dtype)*data
# copy = False
# elif isinstance(data, np.ndarray) or \
# isinstance(data, distributed_data_object):
# data = self.extract_local_data(data)
#
# if data.shape is not self.local_shape:
# copy = True
#
# if copy:
# result_data = np.empty(self.local_shape, dtype=self.dtype)
# result_data[:] = data
# else:
# result_data = data
#
# return result_data
#
# else:
# new_data = np.array(data)
# return new_data.astype(self.dtype,
# copy=copy).reshape(self.global_shape)
#
#
## local_data_available_Q = (data is not None)
## data_available_Q = np.array(comm.allgather(local_data_available_Q))
##
## if np.all(data_available_Q == False):
## return np.empty(self.local_shape, dtype=self.dtype, order='C')
## # if all nodes got data, we assume that it is the right data and
## # store it individually.
## elif np.all(data_available_Q == True):
## if isinstance(data, distributed_data_object):
## temp_d2o = data.get_data((slice(self.local_start,
## self.local_end),),
## local_keys=True,
## copy=copy)
## return temp_d2o.get_local_data(copy=False).astype(self.dtype,
## copy=False)
## elif np.isscalar(data):
## return np.ones(self.local_shape, dtype=self.dtype)*data
## else:
## return data[self.local_start:self.local_end].astype(
## self.dtype,
## copy=copy)
## else:
## raise ValueError(
## "ERROR: distribute_data must get data on all nodes!")
    def _disperse_data_primitive(self, data, to_key, data_update, from_key,
                                 copy, to_found, to_found_boolean, from_found,
...
...
@@ -1403,37 +1470,38 @@ class _slicing_distributor(distributor):
        # if shape-casting was successfull, extract the data
        else:
-           # If the first dimension matches only via broadcasting...
-           # Case 1: ...do broadcasting. This procedure does not depend on the
-           # array type (ndarray or d2o)
-           if matching_dimensions[0] == False:
-               extracted_data = data_object[0:1]
-           # Case 2: First dimension fits directly and data_object is a d2o
-           elif isinstance(data_object, distributed_data_object):
-               # Check if both d2os have the same slicing
-               # If the distributor is exactly the same, extract the data
-               if self is data_object.distributor:
-                   # Simply take the local data
-                   extracted_data = data_object.data
-               # If the distributor is not exactly the same, check if the
-               # geometry matches if it is a slicing distributor
-               # -> comm and local shapes
-               elif (isinstance(data_object.distributor,
-                                _slicing_distributor) and
-                     (self.comm is data_object.distributor.comm) and
-                     (np.all(self.all_local_slices ==
-                             data_object.distributor.all_local_slices))):
+           if isinstance(data_object, distributed_data_object):
+               # If the first dimension matches only via broadcasting...
+               # Case 1: ...do broadcasting.
+               if matching_dimensions[0] == False:
+                   extracted_data = data_object.get_full_data()
+                   extracted_data = extracted_data[0]
+               else:
+                   # Case 2: First dimension fits directly and data_object is
+                   # a d2o
+                   # Check if both d2os have the same slicing
+                   # If the distributor is exactly the same, extract the data
+                   if self is data_object.distributor:
+                       # Simply take the local data
+                       extracted_data = data_object.data
+                   # If the distributor is not exactly the same, check if the
+                   # geometry matches if it is a slicing distributor
+                   # -> comm and local shapes
+                   elif (isinstance(data_object.distributor,
+                                    _slicing_distributor) and
+                         (self.comm is data_object.distributor.comm) and
+                         (np.all(self.all_local_slices ==
+                                 data_object.distributor.all_local_slices))):
+                       extracted_data = data_object.data
+                   else:
+                       # Case 2: no. All nodes extract their local slice from the
+                       # data_object
+                       extracted_data = \
+                           data_object.get_data(slice(self.local_start,
+                                                      self.local_end),
+                                                local_keys=True)
+                       extracted_data = extracted_data.get_local_data()
-               else:
-                   # Case 2: no. All nodes extract their local slice from the
-                   # data_object
-                   extracted_data = \
-                       data_object.get_data(slice(self.local_start,
-                                                  self.local_end),
-                                            local_keys=True)
-                   extracted_data = extracted_data.get_local_data()
# # Check if the distributor and the comm match
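In the rewritten branch above, Case 1 handles a foreign object whose first axis matches the local slab only via broadcasting: the full data is fetched and only its single leading entry is kept, so that ordinary NumPy broadcasting can later replicate it across the local first axis. A plain-NumPy sketch of that idea with hypothetical shapes:

import numpy as np

local_shape = (4, 3)                     # this rank's slab
foreign = np.arange(3).reshape(1, 3)     # first axis matches only via broadcasting

extracted = foreign[0]                   # shape (3,), like get_full_data()[0]
local = np.empty(local_shape)
local[:] = extracted                     # broadcast over the first axis
print(local)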
...
...
@@ -1454,7 +1522,11 @@ class _slicing_distributor(distributor):
# extracted_data = extracted_data.get_local_data()
#
#
# Case 2: np-array
# If the first dimension matches only via broadcasting
# ...do broadcasting.
            elif matching_dimensions[0] == False:
                extracted_data = data_object[0:1]
            # Case 3: First dimension fits directly and data_object is an
            # generic array
            else:
...
...
@@ -1464,6 +1536,7 @@ class _slicing_distributor(distributor):
        return extracted_data

    def _reshape_foreign_data(self, foreign):
        # Case 1:
        # check if the shapes match directly
        if self.global_shape == foreign.shape:
...
...
@@ -2043,24 +2116,42 @@ class _not_distributor(distributor):
        return result_object

    def distribute_data(self, data, alias=None, path=None, copy=True,
                        **kwargs):
        if 'h5py' in gdi and alias is not None:
            data = self.load_data(alias=alias, path=path)

        if data is None:
            return np.empty(self.global_shape, dtype=self.dtype)
        elif isinstance(data, distributed_data_object):
            new_data = data.get_full_data()
        elif isinstance(data, np.ndarray):
            new_data = data
        elif np.isscalar(data):
            new_data = np.ones(self.global_shape, dtype=self.dtype) * data
            copy = False
        else:
            new_data = np.array(data)
        return new_data.astype(self.dtype,
                               copy=copy).reshape(self.global_shape)
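A note on the final cast above: ndarray.astype(dtype, copy=False) returns the input array itself when dtype and layout already match, which is why the scalar branch can safely set copy = False right after building a fresh buffer with np.ones. A quick demonstration:

import numpy as np

a = np.ones((2, 3), dtype=np.float64)
b = a.astype(np.float64, copy=False)   # same buffer, no duplication
c = a.astype(np.float64, copy=True)    # forced copy
print(b is a, c is a)                  # True False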
# def distribute_data(self, data, alias=None, path=None, copy=True,
# **kwargs):
# if 'h5py' in gdi and alias is not None:
# data = self.load_data(alias=alias, path=path)
#
# if data is None:
# return np.empty(self.global_shape, dtype=self.dtype)
# elif np.isscalar(data):
# return np.ones(self.global_shape, dtype=self.dtype)*data
# copy = False
# elif isinstance(data, np.ndarray) or \
# isinstance(data, distributed_data_object):
# data = self.extract_local_data(data)
# result_data = np.empty(self.local_shape, dtype=self.dtype)
# result_data[:] = data
# return result_data
#
# else:
# new_data = np.array(data)
# return new_data.astype(self.dtype,
# copy=copy).reshape(self.global_shape)
#
#
## if data is None:
## return np.empty(self.global_shape, dtype=self.dtype)
## elif isinstance(data, distributed_data_object):
## new_data = data.get_full_data()
## elif isinstance(data, np.ndarray):
## new_data = data
## elif np.isscalar(data):
## new_data = np.ones(self.global_shape, dtype=self.dtype)*data
## copy = False
## else:
## new_data = np.array(data)
## return new_data.astype(self.dtype,
## copy=copy).reshape(self.global_shape)
    def _disperse_data_primitive(self, data, to_key, data_update, from_key,
                                 copy, to_found, to_found_boolean, from_found,
...
...
@@ -2118,9 +2209,15 @@ class _not_distributor(distributor):
    def extract_local_data(self, data_object):
        if isinstance(data_object, distributed_data_object):
-           return data_object.get_full_data().reshape(self.global_shape)
+           result_data = data_object.get_full_data()
        else:
-           return np.array(data_object)[:].reshape(self.global_shape)
+           result_data = np.array(data_object)[:]
+
+       try:
+           result_data = result_data.reshape(self.global_shape)
+       except ValueError:
+           pass
+       return result_data
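The new try/except makes extract_local_data tolerant of inputs whose total size does not match global_shape: reshape raises ValueError in that case and the data is returned unchanged instead of aborting. A tiny standalone illustration of the pattern (the helper name is made up):

import numpy as np

def reshape_if_possible(arr, shape):
    try:
        return arr.reshape(shape)
    except ValueError:
        return arr

print(reshape_if_possible(np.arange(6), (2, 3)).shape)   # (2, 3)
print(reshape_if_possible(np.arange(5), (2, 3)).shape)   # (5,), left as is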
    def flatten(self, data, inplace=False):
        if inplace:
...
...