Commit e10b54aa authored by Ultima's avatar Ultima
Browse files

Cleaned up the field class.

Fixed a bug in nifty_mpi_data.
parent 41dd2ed7
This diff is collapsed.
......@@ -1294,6 +1294,7 @@ class distributor(object):
distribution_strategy='freeform',
comm=self.comm)
# disperse the data one after another
print ('i', i, temp_data_update)
self._disperse_data_primitive(
data=data,
to_key=to_key_list[i],
......@@ -2241,40 +2242,50 @@ class _slicing_distributor(distributor):
# in case of leading scalars, indenify the node with data
# and broadcast the shape to the others
if sliceified[0]:
local_has_data = (np.prod(
np.shape(in_data.get_local_data())
) != 0)
local_has_data_list = np.array(self.comm.allgather(
local_has_data))
nodes_with_data = np.where(local_has_data_list == True)[0]
if np.shape(nodes_with_data)[0] > 1:
raise ValueError(
"ERROR: scalar index on first dimension, but more " +
"than one node has data!")
elif np.shape(nodes_with_data)[0] == 1:
node_with_data = nodes_with_data[0]
else:
node_with_data = -1
# Case 1: The in_data d2o has more than one dimension
if len(in_data.shape) > 1:
local_has_data = (np.prod(
np.shape(
in_data.get_local_data())) != 0)
local_has_data_list = np.array(
self.comm.allgather(local_has_data))
nodes_with_data = np.where(local_has_data_list)[0]
if np.shape(nodes_with_data)[0] > 1:
raise ValueError(
"ERROR: scalar index on first dimension, but " +
" more than one node has data!")
elif np.shape(nodes_with_data)[0] == 1:
node_with_data = nodes_with_data[0]
else:
node_with_data = -1
if node_with_data == -1:
broadcasted_shape = (0,) * len(temp_local_shape)
else:
broadcasted_shape = self.comm.bcast(temp_local_shape,
if node_with_data == -1:
broadcasted_shape = (0,) * len(temp_local_shape)
else:
broadcasted_shape = self.comm.bcast(
temp_local_shape,
root=node_with_data)
if self.comm.rank != node_with_data:
temp_local_shape = np.array(broadcasted_shape)
temp_local_shape[0] = 0
temp_local_shape = tuple(temp_local_shape)
if self.comm.rank != node_with_data:
temp_local_shape = np.array(broadcasted_shape)
temp_local_shape[0] = 0
temp_local_shape = tuple(temp_local_shape)
# Case 2: The in_data d2o is only onedimensional
else:
# The data contained in the d2o must be stored on one
# single node at the end. Hence it is ok to consolidate
# the data and to make a recursive call.
temp_data = in_data.get_full_data()
return self._enfold(temp_data, sliceified)
if in_data.distribution_strategy != 'freeform':
if in_data.distribution_strategy in STRATEGIES['global']:
new_data = in_data.copy_empty(global_shape=temp_global_shape)
new_data.set_local_data(local_data, copy=False)
else:
elif in_data.distribution_strategy in STRATEGIES['local']:
reshaped_data = local_data.reshape(temp_local_shape)
new_data = distributed_data_object(
local_data=reshaped_data,
distribution_strategy='freeform',
comm=self.comm)
local_data=reshaped_data,
distribution_strategy=in_data.distribution_strategy,
comm=self.comm)
return new_data
else:
return local_data.reshape(temp_local_shape)
......
......@@ -743,15 +743,33 @@ class Test_slicing_get_set_data(unittest.TestCase):
###############################################################################
@parameterized.expand(all_distribution_strategies)
@parameterized.expand(all_distribution_strategies,
testcase_func_name=custom_name_func)
def test_get_single_value_from_d2o(self, distribution_strategy):
(a, obj) = generate_data((4,), np.dtype('float'),
distribution_strategy)
assert_equal(obj[0], a[0])
###############################################################################
###############################################################################
@parameterized.expand(
itertools.product(all_distribution_strategies,
all_distribution_strategies),
testcase_func_name=custom_name_func)
def test_single_row_from_d2o(self, distribution_strategy1,
distribution_strategy2):
(a, obj) = generate_data((8, 8), np.dtype('float'),
distribution_strategy1)
(b, p) = generate_data((8,), np.dtype('float'),
distribution_strategy2)
a[4] = b
obj[4] = p
assert_equal(obj.get_full_data(), a)
###############################################################################
###############################################################################
class Test_boolean_get_set_data(unittest.TestCase):
......
......@@ -22,16 +22,16 @@ from nifty.nifty_paradict import space_paradict
from nifty.nifty_core import POINT_DISTRIBUTION_STRATEGIES
from nifty.rg.nifty_rg import RG_DISTRIBUTION_STRATEGIES,\
gc as RG_GC
gc as RG_GC
from nifty.lm.nifty_lm import LM_DISTRIBUTION_STRATEGIES,\
GL_DISTRIBUTION_STRATEGIES,\
HP_DISTRIBUTION_STRATEGIES
GL_DISTRIBUTION_STRATEGIES,\
HP_DISTRIBUTION_STRATEGIES
from nifty.nifty_power_indices import power_indices
from nifty.nifty_utilities import _hermitianize_inverter as \
hermitianize_inverter
hermitianize_inverter
###############################################################################
###############################################################################
def custom_name_func(testcase_func, param_num, param):
return "%s_%s" % (
......@@ -169,10 +169,10 @@ def check_almost_equality(space, data1, data2, integers=7):
def flip(space, data):
return space.unary_operation(hermitianize_inverter(data), 'conjugate')
###############################################################################
###############################################################################
class Test_Common_Space_Features(unittest.TestCase):
@parameterized.expand(all_spaces,
......@@ -195,7 +195,6 @@ class Test_Common_Space_Features(unittest.TestCase):
assert(callable(s.apply_scalar_function))
assert(callable(s.unary_operation))
assert(callable(s.binary_operation))
assert(callable(s.get_norm))
assert(callable(s.get_shape))
assert(callable(s.get_dim))
assert(callable(s.get_dof))
......@@ -207,6 +206,7 @@ class Test_Common_Space_Features(unittest.TestCase):
assert(callable(s.get_random_values))
assert(callable(s.calc_weight))
assert(callable(s.get_weight))
assert(callable(s.calc_norm))
assert(callable(s.calc_dot))
assert(callable(s.calc_transform))
assert(callable(s.calc_smooth))
......@@ -346,18 +346,6 @@ class Test_Point_Space(unittest.TestCase):
s.binary_operation(d, d2, op)
# TODO: Implement value verification
###############################################################################
@parameterized.expand(
itertools.product(DATAMODELS['point_space']),
testcase_func_name=custom_name_func)
def test_get_norm(self, datamodel):
num = 10
s = point_space(num, datamodel=datamodel)
d = s.cast(np.arange(num))
assert_almost_equal(s.get_norm(d), 16.881943016134134)
assert_almost_equal(s.get_norm(d, q=3), 12.651489979526238)
###############################################################################
@parameterized.expand(
......@@ -599,6 +587,18 @@ class Test_Point_Space(unittest.TestCase):
assert_equal(s.calc_dot(1, 1), num)
assert_equal(s.calc_dot(np.arange(num), 1), num * (num - 1.) / 2.)
###############################################################################
@parameterized.expand(
itertools.product(DATAMODELS['point_space']),
testcase_func_name=custom_name_func)
def test_calc_norm(self, datamodel):
num = 10
s = point_space(num, datamodel=datamodel)
d = s.cast(np.arange(num))
assert_almost_equal(s.calc_norm(d), 16.881943016134134)
assert_almost_equal(s.calc_norm(d, q=3), 12.651489979526238)
###############################################################################
@parameterized.expand(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment