Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
D2O
Commits
b1a824c8
Commit
b1a824c8
authored
May 26, 2016
by
theos
Browse files
Bugfixes: Added h5py to dependency_injector. Fixed apply_scalar_function (np.vectorize) and argmax.
parent
889b7964
Pipeline
#3842
skipped
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
d2o/config/d2o_config.py
View file @
b1a824c8
...
...
@@ -4,8 +4,9 @@ import keepers
# Setup the dependency injector
dependency_injector
=
keepers
.
DependencyInjector
(
[(
'mpi4py.MPI'
,
'MPI'
),
(
'mpi_dummy'
,
'MPI_dummy'
)]
[
'h5py'
,
(
'mpi4py.MPI'
,
'MPI'
),
(
'd2o.mpi_dummy.mpi_dummy'
,
'MPI_dummy'
)]
)
dependency_injector
.
register
(
'pyfftw'
,
lambda
z
:
hasattr
(
z
,
'FFTW_MPI'
))
...
...
d2o/distributed_data_object.py
View file @
b1a824c8
...
...
@@ -417,7 +417,8 @@ class distributed_data_object(object):
except
:
about_warnings_cprint
(
"WARNING: Trying to use np.vectorize!"
)
result_data
=
np
.
vectorize
(
function
)(
local_data
)
result_data
=
np
.
vectorize
(
function
,
otypes
=
[
local_data
.
dtype
])(
local_data
)
if
inplace
is
True
:
result_d2o
=
self
...
...
@@ -1244,13 +1245,17 @@ class distributed_data_object(object):
"keyword"
)
if
0
in
self
.
local_shape
:
local_argmax
=
np
.
nan
local_argmax_value
=
np
.
nan
local_argmax_value
=
-
np
.
inf
globalized_local_argmax
=
np
.
nan
else
:
local_argmax
=
np
.
argmax
(
self
.
data
)
local_argmax_value
=
-
self
.
data
[
np
.
unravel_index
(
local_argmax
,
self
.
data
.
shape
)]
globalized_local_argmax
=
self
.
distributor
.
globalize_flat_index
(
local_argmax_value
=
self
.
data
[
np
.
unravel_index
(
local_argmax
,
self
.
data
.
shape
)]
# instead of inverting the sign of local_argmax_value, invert
# the value of the index. Inverting the former leads to errors
# when the dtype is unsigned (uint). By inverting the latter
# we can extract the last entry from the sorted list below
globalized_local_argmax
=
-
self
.
distributor
.
globalize_flat_index
(
local_argmax
)
local_argmax_list
=
self
.
distributor
.
_allgather
(
(
local_argmax_value
,
...
...
@@ -1260,7 +1265,8 @@ class distributed_data_object(object):
(
'index'
,
np
.
dtype
(
'float'
))])
local_argmax_list
=
np
.
sort
(
local_argmax_list
,
order
=
[
'value'
,
'index'
])
return
np
.
int
(
local_argmax_list
[
0
][
1
])
# take the last entry here and correct the minus sign of the index
return
-
np
.
int
(
local_argmax_list
[
-
1
][
1
])
def
argmin_nonflat
(
self
,
axis
=
None
):
""" Returns the unraveld index of the d2o's smallest value.
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment