Commit 154a163f authored by Theo Steininger's avatar Theo Steininger
Browse files

Merge branch 'minor-fixes-for-axis' into 'add_axis_keyword_to_d2o'

Minor fixes for axis



See merge request !10
parents f5a9d3ac 5599041c
Pipeline #2029 skipped
# -*- coding: utf-8 -*-
import numbers
import numpy as np
from nifty import about
......@@ -11,7 +10,7 @@ def cast_axis_to_tuple(axis):
try:
axis = tuple([int(item) for item in axis])
except(TypeError):
if isinstance(axis, numbers.Number):
if np.isscalar(axis):
axis = (int(axis), )
else:
raise TypeError(about._errors.cstring(
......
# -*- coding: utf-8 -*-
import numbers as numbers
import numpy as np
from nifty.keepers import about,\
......@@ -1177,7 +1175,7 @@ class distributed_data_object(object):
def std(self, axis=None):
""" Returns the standard deviation of the d2o's elements. """
var = self.var(axis=axis)
if isinstance(var, numbers.Number):
if np.isscalar(var):
return np.sqrt(var)
else:
return var.apply_scalar_function(np.sqrt)
......@@ -1294,7 +1292,7 @@ class distributed_data_object(object):
about.warnings.cprint(
"WARNING: The current implementation of median is very expensive!")
median = np.median(self.get_full_data(), axis=axis, **kwargs)
if isinstance(median, numbers.Number):
if np.isscalar(median):
return median
else:
x = self.copy_empty(global_shape=median.shape,
......@@ -1303,7 +1301,6 @@ class distributed_data_object(object):
x.set_local_data(median)
return x
def _is_helper(self, function):
""" _is_helper is used for functions like isreal, isinf, isfinite,...
......
......@@ -517,14 +517,32 @@ class _slicing_distributor(distributor):
# check if additional contraction along the first axis must be done
if axis is None or 0 in axis:
(mpi_op, bufferQ) = op_translate_dict[function]
contracted_local_data = self.comm.allreduce(contracted_local_data,
op=mpi_op)
# check if allreduce must be used instead of Allreduce
use_Uppercase = False
if bufferQ and isinstance(contracted_local_data, np.ndarray):
# MPI.MAX and MPI.MIN do not support complex data types
if not np.issubdtype(contracted_local_data.dtype,
np.complexfloating):
use_Uppercase = True
if use_Uppercase:
global_contracted_local_data = np.empty_like(
contracted_local_data)
new_mpi_dtype = self._my_dtype_converter.to_mpi(new_dtype)
self.comm.Allreduce([contracted_local_data,
new_mpi_dtype],
[global_contracted_local_data,
new_mpi_dtype],
op=mpi_op)
else:
global_contracted_local_data = self.comm.allreduce(
contracted_local_data, op=mpi_op)
new_dist_strategy = 'not'
else:
new_dist_strategy = parent.distribution_strategy
global_contracted_local_data = contracted_local_data
if new_shape == ():
result = contracted_local_data
result = global_contracted_local_data
else:
# try to store the result in a distributed_data_object with the
# distribution_strategy as parent
......@@ -538,12 +556,12 @@ class _slicing_distributor(distributor):
# Contracting (4, 4) to (4,).
# (4, 4) was distributed (1, 4)...(1, 4)
# (4, ) is not distributed like (1,)...(1,) but like (2,)(2,)()()!
if result.local_shape != contracted_local_data.shape:
if result.local_shape != global_contracted_local_data.shape:
result = parent.copy_empty(
local_shape=contracted_local_data.shape,
local_shape=global_contracted_local_data.shape,
dtype=new_dtype,
distribution_strategy='freeform')
result.set_local_data(contracted_local_data, copy=False)
result.set_local_data(global_contracted_local_data, copy=False)
return result
......
......@@ -8,10 +8,10 @@ from nifty.keepers import global_configuration as gc,\
MPI = gdi[gc['mpi_module']]
custom_NANMIN = MPI.Op.Create(lambda x, y, datatype:
np.nanmin(np.vstack(x, y), axis=0))
np.nanmin(np.vstack((x, y)), axis=0))
custom_NANMAX = MPI.Op.Create(lambda x, y, datatype:
np.nanmax(np.vstack(x, y), axis=0))
np.nanmax(np.vstack((x, y)), axis=0))
custom_UNIQUE = MPI.Op.Create(lambda x, y, datatype:
np.unique(np.concatenate([x, y])))
......@@ -24,8 +24,8 @@ op_translate_dict[np.sum] = (MPI.SUM, True)
op_translate_dict[np.prod] = (MPI.PROD, True)
op_translate_dict[np.amin] = (MPI.MIN, True)
op_translate_dict[np.amax] = (MPI.MAX, True)
op_translate_dict[np.all] = (MPI.LAND, True)
op_translate_dict[np.any] = (MPI.LOR, True)
op_translate_dict[np.all] = (MPI.BAND, True)
op_translate_dict[np.any] = (MPI.BOR, True)
op_translate_dict[np.nanmin] = (custom_NANMIN, False)
op_translate_dict[np.nanmax] = (custom_NANMAX, False)
op_translate_dict[np.unique] = (custom_UNIQUE, False)
# -*- coding: utf-8 -*-
import numbers
import copy
import numpy as np
......@@ -89,7 +88,7 @@ class Intracomm(Comm):
return recvbuf
def allreduce(self, sendobj, op=SUM, **kwargs):
if isinstance(sendobj, numbers.Number):
if np.isscalar(sendobj):
return sendobj
return copy.copy(sendobj)
......
......@@ -17,6 +17,8 @@ import nifty
from nifty.d2o import distributed_data_object,\
STRATEGIES
from distutils.version import LooseVersion as lv
FOUND = {}
try:
import h5py
......@@ -1049,8 +1051,8 @@ class Test_set_data_via_injection(unittest.TestCase):
all_distribution_strategies
), testcase_func_name=custom_name_func)
def test_set_data_via_injection(self, (global_shape_1, slice_tuple_1,
global_shape_2, slice_tuple_2),
distribution_strategy):
global_shape_2, slice_tuple_2),
distribution_strategy):
dtype = np.dtype('float')
(a, obj) = generate_data(global_shape_1, dtype,
distribution_strategy)
......@@ -1059,8 +1061,8 @@ class Test_set_data_via_injection(unittest.TestCase):
distribution_strategy)
obj.set_data(to_key=slice_tuple_1,
data=p,
from_key=slice_tuple_2)
data=p,
from_key=slice_tuple_2)
a[slice_tuple_1] = b[slice_tuple_2]
assert_equal(obj.get_full_data(), a)
......@@ -1601,9 +1603,9 @@ class Test_comparisons(unittest.TestCase):
class Test_special_methods(unittest.TestCase):
@parameterized.expand(
itertools.product(all_distribution_strategies,
all_distribution_strategies),
testcase_func_name=custom_name_func)
itertools.product(all_distribution_strategies,
all_distribution_strategies),
testcase_func_name=custom_name_func)
def test_bincount(self, distribution_strategy_1, distribution_strategy_2):
global_shape = (10,)
dtype = np.dtype('int')
......@@ -1742,12 +1744,11 @@ if FOUND['h5py'] == True:
class Test_axis(unittest.TestCase):
@parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'median', 'all',
'any', 'min', 'amin', 'nanmin', 'argmin',
'argmin_nonflat', 'max', 'amax', 'nanmax',
'argmax', 'argmax_nonflat'],
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
'min', 'amin', 'nanmin', 'argmin', 'argmin_nonflat',
'max', 'amax', 'nanmax', 'argmax', 'argmax_nonflat'],
all_datatypes[1:],
[(0,), (1,), (6, 6), (5, 5, 5)],
[(0,), (1,), (6, 6), (4, 4, 3)],
all_distribution_strategies,
[None, 0, (1, ), (0, 1)]),
testcase_func_name=custom_name_func)
......@@ -1774,12 +1775,11 @@ class Test_axis(unittest.TestCase):
decimal=4)
@parameterized.expand(
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'median', 'all',
'any', 'min', 'amin', 'nanmin', 'argmin',
'argmin_nonflat', 'max', 'amax', 'nanmax',
'argmax', 'argmax_nonflat'],
itertools.product(['sum', 'prod', 'mean', 'var', 'std', 'all', 'any',
'min', 'amin', 'nanmin', 'argmin', 'argmin_nonflat',
'max', 'amax', 'nanmax', 'argmax', 'argmax_nonflat'],
all_datatypes[1:],
[(5, 5, 5), (4, 0, 3)],
[(4, 4, 3), (4, 0, 3)],
all_distribution_strategies, [(0, 1), (1, 2)]),
testcase_func_name=custom_name_func)
def test_axis_with_functions_for_many_dimentions(self, function, dtype,
......@@ -1798,4 +1798,27 @@ class Test_axis(unittest.TestCase):
else:
assert_almost_equal(getattr(obj, function)
(axis=axis).get_full_data(),
getattr(np, function)(a, axis=axis), decimal=4)
getattr(np, function)(a, axis=axis),
decimal=4)
@parameterized.expand(
itertools.product(all_datatypes[1:],
[(0,), (1,), (4, 4, 3), (4, 0, 3)],
all_distribution_strategies,
[None, 0, (1, ), (0, 1), (1, 2)]),
testcase_func_name=custom_name_func)
def test_axis_for_median(self, dtype, global_shape,
distribution_strategy, axis):
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy,
strictly_positive=True)
if global_shape != (0,) and global_shape != (1,) and lv("1.9.0Z") < \
lv(np.__version__):
assert_almost_equal(getattr(obj, 'median')(axis=axis),
getattr(np, 'median')(a, axis=axis),
decimal=4)
else:
if axis is None or axis == 0 or axis == (0,):
assert_almost_equal(getattr(obj, 'median')(axis=axis),
getattr(np, 'median')(a, axis=axis),
decimal=4)
......@@ -31,19 +31,19 @@ from nifty.operators.nifty_operators import power_operator
available = []
try:
from nifty import lm_space
from nifty import lm_space
except ImportError:
pass
else:
available += ['lm_space']
try:
from nifty import gl_space
from nifty import gl_space
except ImportError:
pass
else:
available += ['gl_space']
try:
from nifty import hp_space
from nifty import hp_space
except ImportError:
pass
else:
......@@ -1364,7 +1364,7 @@ class Test_axis(unittest.TestCase):
[None, (0,)],
DATAMODELS['point_space']),
testcase_func_name=custom_name_func)
def test_binary_operations(self, name, num, op, axis, datamodel):
def test_unary_operations(self, name, num, op, axis, datamodel):
s = generate_space_with_size(name, np.prod(num), datamodel=datamodel)
d = generate_data(s)
a = d.get_full_data()
......@@ -1375,4 +1375,6 @@ class Test_axis(unittest.TestCase):
getattr(np, op)(a, axis=axis), decimal=4)
if name in ['rg_space']:
assert_almost_equal(s.unary_operation(d, op, axis=(0, 1)),
getattr(np, op)(a, axis=(0, 1)), decimal=4)
\ No newline at end of file
getattr(np, op)(a, axis=(0, 1)), decimal=4)
assert_almost_equal(s.unary_operation(d, op, axis=(1,)),
getattr(np, op)(a, axis=(1,)), decimal=4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment