Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
23657529
Commit
23657529
authored
Apr 26, 2016
by
theos
Browse files
Fixed number.Numbers checks. Fixed failing of MPI.MIN and MPI.MAX for complex dtypes.
parent
f39083b1
Pipeline
#2004
skipped
Changes
5
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
d2o/cast_axis_to_tuple.py
View file @
23657529
# -*- coding: utf-8 -*-
import
numbers
import
numpy
as
np
from
nifty
import
about
...
...
@@ -11,7 +10,7 @@ def cast_axis_to_tuple(axis):
try
:
axis
=
tuple
([
int
(
item
)
for
item
in
axis
])
except
(
TypeError
):
if
isinstance
(
axis
,
numbers
.
Number
):
if
np
.
isscalar
(
axis
):
axis
=
(
int
(
axis
),
)
else
:
raise
TypeError
(
about
.
_errors
.
cstring
(
...
...
d2o/distributed_data_object.py
View file @
23657529
# -*- coding: utf-8 -*-
import
numbers
as
numbers
import
numpy
as
np
from
nifty.keepers
import
about
,
\
...
...
@@ -1177,7 +1175,7 @@ class distributed_data_object(object):
def
std
(
self
,
axis
=
None
):
""" Returns the standard deviation of the d2o's elements. """
var
=
self
.
var
(
axis
=
axis
)
if
isinstance
(
var
,
numbers
.
Numbe
r
):
if
np
.
isscalar
(
va
r
):
return
np
.
sqrt
(
var
)
else
:
return
var
.
apply_scalar_function
(
np
.
sqrt
)
...
...
@@ -1294,7 +1292,7 @@ class distributed_data_object(object):
about
.
warnings
.
cprint
(
"WARNING: The current implementation of median is very expensive!"
)
median
=
np
.
median
(
self
.
get_full_data
(),
axis
=
axis
,
**
kwargs
)
if
isinstance
(
median
,
numbers
.
Number
):
if
np
.
isscalar
(
median
):
return
median
else
:
x
=
self
.
copy_empty
(
global_shape
=
median
.
shape
,
...
...
@@ -1303,7 +1301,6 @@ class distributed_data_object(object):
x
.
set_local_data
(
median
)
return
x
def
_is_helper
(
self
,
function
):
""" _is_helper is used for functions like isreal, isinf, isfinite,...
...
...
d2o/distributor_factory.py
View file @
23657529
...
...
@@ -517,7 +517,14 @@ class _slicing_distributor(distributor):
# check if additional contraction along the first axis must be done
if
axis
is
None
or
0
in
axis
:
(
mpi_op
,
bufferQ
)
=
op_translate_dict
[
function
]
# check if allreduce must be used instead of Allreduce
use_Uppercase
=
False
if
bufferQ
and
isinstance
(
contracted_local_data
,
np
.
ndarray
):
# MPI.MAX and MPI.MIN do not support complex data types
if
not
np
.
issubdtype
(
contracted_local_data
.
dtype
,
np
.
complexfloating
):
use_Uppercase
=
True
if
use_Uppercase
:
global_contracted_local_data
=
np
.
empty_like
(
contracted_local_data
)
new_mpi_dtype
=
self
.
_my_dtype_converter
.
to_mpi
(
new_dtype
)
...
...
dummys/MPI_dummy.py
View file @
23657529
# -*- coding: utf-8 -*-
import
numbers
import
copy
import
numpy
as
np
...
...
@@ -89,7 +88,7 @@ class Intracomm(Comm):
return
recvbuf
def
allreduce
(
self
,
sendobj
,
op
=
SUM
,
**
kwargs
):
if
isinstance
(
sendobj
,
numbers
.
Number
):
if
np
.
isscalar
(
sendobj
):
return
sendobj
return
copy
.
copy
(
sendobj
)
...
...
test/test_nifty_mpi_data.py
View file @
23657529
...
...
@@ -1049,8 +1049,8 @@ class Test_set_data_via_injection(unittest.TestCase):
all_distribution_strategies
),
testcase_func_name
=
custom_name_func
)
def
test_set_data_via_injection
(
self
,
(
global_shape_1
,
slice_tuple_1
,
global_shape_2
,
slice_tuple_2
),
distribution_strategy
):
global_shape_2
,
slice_tuple_2
),
distribution_strategy
):
dtype
=
np
.
dtype
(
'float'
)
(
a
,
obj
)
=
generate_data
(
global_shape_1
,
dtype
,
distribution_strategy
)
...
...
@@ -1059,8 +1059,8 @@ class Test_set_data_via_injection(unittest.TestCase):
distribution_strategy
)
obj
.
set_data
(
to_key
=
slice_tuple_1
,
data
=
p
,
from_key
=
slice_tuple_2
)
data
=
p
,
from_key
=
slice_tuple_2
)
a
[
slice_tuple_1
]
=
b
[
slice_tuple_2
]
assert_equal
(
obj
.
get_full_data
(),
a
)
...
...
@@ -1601,9 +1601,9 @@ class Test_comparisons(unittest.TestCase):
class
Test_special_methods
(
unittest
.
TestCase
):
@
parameterized
.
expand
(
itertools
.
product
(
all_distribution_strategies
,
all_distribution_strategies
),
testcase_func_name
=
custom_name_func
)
itertools
.
product
(
all_distribution_strategies
,
all_distribution_strategies
),
testcase_func_name
=
custom_name_func
)
def
test_bincount
(
self
,
distribution_strategy_1
,
distribution_strategy_2
):
global_shape
=
(
10
,)
dtype
=
np
.
dtype
(
'int'
)
...
...
@@ -1742,8 +1742,8 @@ if FOUND['h5py'] == True:
class
Test_axis
(
unittest
.
TestCase
):
@
parameterized
.
expand
(
itertools
.
product
([
'sum'
,
'prod'
,
'mean'
,
'var'
,
'std'
,
'median'
,
'all'
,
'any'
,
'min'
,
'amin'
,
'nanmin'
,
'argmin'
,
itertools
.
product
([
'sum'
,
'prod'
,
'mean'
,
'var'
,
'std'
,
#
'median',
'all'
,
'any'
,
'min'
,
'amin'
,
'nanmin'
,
'argmin'
,
'argmin_nonflat'
,
'max'
,
'amax'
,
'nanmax'
,
'argmax'
,
'argmax_nonflat'
],
all_datatypes
[
1
:],
...
...
@@ -1774,8 +1774,8 @@ class Test_axis(unittest.TestCase):
decimal
=
4
)
@
parameterized
.
expand
(
itertools
.
product
([
'sum'
,
'prod'
,
'mean'
,
'var'
,
'std'
,
'median'
,
'all'
,
'any'
,
'min'
,
'amin'
,
'nanmin'
,
'argmin'
,
itertools
.
product
([
'sum'
,
'prod'
,
'mean'
,
'var'
,
'std'
,
#
'median',
'all'
,
'any'
,
'min'
,
'amin'
,
'nanmin'
,
'argmin'
,
'argmin_nonflat'
,
'max'
,
'amax'
,
'nanmax'
,
'argmax'
,
'argmax_nonflat'
],
all_datatypes
[
1
:],
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment