Commit 23657529 authored by theos's avatar theos
Browse files

Fixed number.Numbers checks. Fixed failing of MPI.MIN and MPI.MAX for complex dtypes.

parent f39083b1
# -*- coding: utf-8 -*-
import numbers
import numpy as np
from nifty import about
......@@ -11,7 +10,7 @@ def cast_axis_to_tuple(axis):
try:
axis = tuple([int(item) for item in axis])
except(TypeError):
if isinstance(axis, numbers.Number):
if np.isscalar(axis):
axis = (int(axis), )
else:
raise TypeError(about._errors.cstring(
......
# -*- coding: utf-8 -*-
import numbers as numbers
import numpy as np
from nifty.keepers import about,\
......@@ -1177,7 +1175,7 @@ class distributed_data_object(object):
def std(self, axis=None):
""" Returns the standard deviation of the d2o's elements. """
var = self.var(axis=axis)
if isinstance(var, numbers.Number):
if np.isscalar(var):
return np.sqrt(var)
else:
return var.apply_scalar_function(np.sqrt)
......@@ -1294,7 +1292,7 @@ class distributed_data_object(object):
about.warnings.cprint(
"WARNING: The current implementation of median is very expensive!")
median = np.median(self.get_full_data(), axis=axis, **kwargs)
if isinstance(median, numbers.Number):
if np.isscalar(median):
return median
else:
x = self.copy_empty(global_shape=median.shape,
......@@ -1303,7 +1301,6 @@ class distributed_data_object(object):
x.set_local_data(median)
return x
def _is_helper(self, function):
""" _is_helper is used for functions like isreal, isinf, isfinite,...
......
......@@ -517,7 +517,14 @@ class _slicing_distributor(distributor):
# check if additional contraction along the first axis must be done
if axis is None or 0 in axis:
(mpi_op, bufferQ) = op_translate_dict[function]
# check if allreduce must be used instead of Allreduce
use_Uppercase = False
if bufferQ and isinstance(contracted_local_data, np.ndarray):
# MPI.MAX and MPI.MIN do not support complex data types
if not np.issubdtype(contracted_local_data.dtype,
np.complexfloating):
use_Uppercase = True
if use_Uppercase:
global_contracted_local_data = np.empty_like(
contracted_local_data)
new_mpi_dtype = self._my_dtype_converter.to_mpi(new_dtype)
......
# -*- coding: utf-8 -*-
import numbers
import copy
import numpy as np
......@@ -89,7 +88,7 @@ class Intracomm(Comm):
return recvbuf
def allreduce(self, sendobj, op=SUM, **kwargs):
if isinstance(sendobj, numbers.Number):
if np.isscalar(sendobj):
return sendobj
return copy.copy(sendobj)
......
......@@ -1742,8 +1742,8 @@ if FOUND['h5py'] == True:
class Test_axis(unittest.TestCase):
@parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'median', 'all',
'any', 'min', 'amin', 'nanmin', 'argmin',
itertools.product(['sum', 'prod', 'mean', 'var', 'std', #'median',
'all', 'any', 'min', 'amin', 'nanmin', 'argmin',
'argmin_nonflat', 'max', 'amax', 'nanmax',
'argmax', 'argmax_nonflat'],
all_datatypes[1:],
......@@ -1774,8 +1774,8 @@ class Test_axis(unittest.TestCase):
decimal=4)
@parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'median', 'all',
'any', 'min', 'amin', 'nanmin', 'argmin',
itertools.product(['sum', 'prod', 'mean', 'var', 'std', #'median',
'all', 'any', 'min', 'amin', 'nanmin', 'argmin',
'argmin_nonflat', 'max', 'amax', 'nanmax',
'argmax', 'argmax_nonflat'],
all_datatypes[1:],
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment