Commit e10b54aa authored by Ultima's avatar Ultima
Browse files

Cleaned up the field class.

Fixed a bug in nifty_mpi_data.
parent 41dd2ed7
This diff is collapsed.
...@@ -1294,6 +1294,7 @@ class distributor(object): ...@@ -1294,6 +1294,7 @@ class distributor(object):
distribution_strategy='freeform', distribution_strategy='freeform',
comm=self.comm) comm=self.comm)
# disperse the data one after another # disperse the data one after another
print ('i', i, temp_data_update)
self._disperse_data_primitive( self._disperse_data_primitive(
data=data, data=data,
to_key=to_key_list[i], to_key=to_key_list[i],
...@@ -2241,40 +2242,50 @@ class _slicing_distributor(distributor): ...@@ -2241,40 +2242,50 @@ class _slicing_distributor(distributor):
# in case of leading scalars, indenify the node with data # in case of leading scalars, indenify the node with data
# and broadcast the shape to the others # and broadcast the shape to the others
if sliceified[0]: if sliceified[0]:
local_has_data = (np.prod( # Case 1: The in_data d2o has more than one dimension
np.shape(in_data.get_local_data()) if len(in_data.shape) > 1:
) != 0) local_has_data = (np.prod(
local_has_data_list = np.array(self.comm.allgather( np.shape(
local_has_data)) in_data.get_local_data())) != 0)
nodes_with_data = np.where(local_has_data_list == True)[0] local_has_data_list = np.array(
if np.shape(nodes_with_data)[0] > 1: self.comm.allgather(local_has_data))
raise ValueError( nodes_with_data = np.where(local_has_data_list)[0]
"ERROR: scalar index on first dimension, but more " + if np.shape(nodes_with_data)[0] > 1:
"than one node has data!") raise ValueError(
elif np.shape(nodes_with_data)[0] == 1: "ERROR: scalar index on first dimension, but " +
node_with_data = nodes_with_data[0] " more than one node has data!")
else: elif np.shape(nodes_with_data)[0] == 1:
node_with_data = -1 node_with_data = nodes_with_data[0]
else:
node_with_data = -1
if node_with_data == -1: if node_with_data == -1:
broadcasted_shape = (0,) * len(temp_local_shape) broadcasted_shape = (0,) * len(temp_local_shape)
else: else:
broadcasted_shape = self.comm.bcast(temp_local_shape, broadcasted_shape = self.comm.bcast(
temp_local_shape,
root=node_with_data) root=node_with_data)
if self.comm.rank != node_with_data: if self.comm.rank != node_with_data:
temp_local_shape = np.array(broadcasted_shape) temp_local_shape = np.array(broadcasted_shape)
temp_local_shape[0] = 0 temp_local_shape[0] = 0
temp_local_shape = tuple(temp_local_shape) temp_local_shape = tuple(temp_local_shape)
# Case 2: The in_data d2o is only onedimensional
else:
# The data contained in the d2o must be stored on one
# single node at the end. Hence it is ok to consolidate
# the data and to make a recursive call.
temp_data = in_data.get_full_data()
return self._enfold(temp_data, sliceified)
if in_data.distribution_strategy != 'freeform': if in_data.distribution_strategy in STRATEGIES['global']:
new_data = in_data.copy_empty(global_shape=temp_global_shape) new_data = in_data.copy_empty(global_shape=temp_global_shape)
new_data.set_local_data(local_data, copy=False) new_data.set_local_data(local_data, copy=False)
else: elif in_data.distribution_strategy in STRATEGIES['local']:
reshaped_data = local_data.reshape(temp_local_shape) reshaped_data = local_data.reshape(temp_local_shape)
new_data = distributed_data_object( new_data = distributed_data_object(
local_data=reshaped_data, local_data=reshaped_data,
distribution_strategy='freeform', distribution_strategy=in_data.distribution_strategy,
comm=self.comm) comm=self.comm)
return new_data return new_data
else: else:
return local_data.reshape(temp_local_shape) return local_data.reshape(temp_local_shape)
......
...@@ -743,15 +743,33 @@ class Test_slicing_get_set_data(unittest.TestCase): ...@@ -743,15 +743,33 @@ class Test_slicing_get_set_data(unittest.TestCase):
############################################################################### ###############################################################################
@parameterized.expand(all_distribution_strategies) @parameterized.expand(all_distribution_strategies,
testcase_func_name=custom_name_func)
def test_get_single_value_from_d2o(self, distribution_strategy): def test_get_single_value_from_d2o(self, distribution_strategy):
(a, obj) = generate_data((4,), np.dtype('float'), (a, obj) = generate_data((4,), np.dtype('float'),
distribution_strategy) distribution_strategy)
assert_equal(obj[0], a[0]) assert_equal(obj[0], a[0])
############################################################################### ###############################################################################
###############################################################################
@parameterized.expand(
itertools.product(all_distribution_strategies,
all_distribution_strategies),
testcase_func_name=custom_name_func)
def test_single_row_from_d2o(self, distribution_strategy1,
distribution_strategy2):
(a, obj) = generate_data((8, 8), np.dtype('float'),
distribution_strategy1)
(b, p) = generate_data((8,), np.dtype('float'),
distribution_strategy2)
a[4] = b
obj[4] = p
assert_equal(obj.get_full_data(), a)
###############################################################################
###############################################################################
class Test_boolean_get_set_data(unittest.TestCase): class Test_boolean_get_set_data(unittest.TestCase):
......
...@@ -22,16 +22,16 @@ from nifty.nifty_paradict import space_paradict ...@@ -22,16 +22,16 @@ from nifty.nifty_paradict import space_paradict
from nifty.nifty_core import POINT_DISTRIBUTION_STRATEGIES from nifty.nifty_core import POINT_DISTRIBUTION_STRATEGIES
from nifty.rg.nifty_rg import RG_DISTRIBUTION_STRATEGIES,\ from nifty.rg.nifty_rg import RG_DISTRIBUTION_STRATEGIES,\
gc as RG_GC gc as RG_GC
from nifty.lm.nifty_lm import LM_DISTRIBUTION_STRATEGIES,\ from nifty.lm.nifty_lm import LM_DISTRIBUTION_STRATEGIES,\
GL_DISTRIBUTION_STRATEGIES,\ GL_DISTRIBUTION_STRATEGIES,\
HP_DISTRIBUTION_STRATEGIES HP_DISTRIBUTION_STRATEGIES
from nifty.nifty_power_indices import power_indices from nifty.nifty_power_indices import power_indices
from nifty.nifty_utilities import _hermitianize_inverter as \ from nifty.nifty_utilities import _hermitianize_inverter as \
hermitianize_inverter hermitianize_inverter
###############################################################################
###############################################################################
def custom_name_func(testcase_func, param_num, param): def custom_name_func(testcase_func, param_num, param):
return "%s_%s" % ( return "%s_%s" % (
...@@ -169,10 +169,10 @@ def check_almost_equality(space, data1, data2, integers=7): ...@@ -169,10 +169,10 @@ def check_almost_equality(space, data1, data2, integers=7):
def flip(space, data): def flip(space, data):
return space.unary_operation(hermitianize_inverter(data), 'conjugate') return space.unary_operation(hermitianize_inverter(data), 'conjugate')
############################################################################### ###############################################################################
############################################################################### ###############################################################################
class Test_Common_Space_Features(unittest.TestCase): class Test_Common_Space_Features(unittest.TestCase):
@parameterized.expand(all_spaces, @parameterized.expand(all_spaces,
...@@ -195,7 +195,6 @@ class Test_Common_Space_Features(unittest.TestCase): ...@@ -195,7 +195,6 @@ class Test_Common_Space_Features(unittest.TestCase):
assert(callable(s.apply_scalar_function)) assert(callable(s.apply_scalar_function))
assert(callable(s.unary_operation)) assert(callable(s.unary_operation))
assert(callable(s.binary_operation)) assert(callable(s.binary_operation))
assert(callable(s.get_norm))
assert(callable(s.get_shape)) assert(callable(s.get_shape))
assert(callable(s.get_dim)) assert(callable(s.get_dim))
assert(callable(s.get_dof)) assert(callable(s.get_dof))
...@@ -207,6 +206,7 @@ class Test_Common_Space_Features(unittest.TestCase): ...@@ -207,6 +206,7 @@ class Test_Common_Space_Features(unittest.TestCase):
assert(callable(s.get_random_values)) assert(callable(s.get_random_values))
assert(callable(s.calc_weight)) assert(callable(s.calc_weight))
assert(callable(s.get_weight)) assert(callable(s.get_weight))
assert(callable(s.calc_norm))
assert(callable(s.calc_dot)) assert(callable(s.calc_dot))
assert(callable(s.calc_transform)) assert(callable(s.calc_transform))
assert(callable(s.calc_smooth)) assert(callable(s.calc_smooth))
...@@ -346,18 +346,6 @@ class Test_Point_Space(unittest.TestCase): ...@@ -346,18 +346,6 @@ class Test_Point_Space(unittest.TestCase):
s.binary_operation(d, d2, op) s.binary_operation(d, d2, op)
# TODO: Implement value verification # TODO: Implement value verification
###############################################################################
@parameterized.expand(
itertools.product(DATAMODELS['point_space']),
testcase_func_name=custom_name_func)
def test_get_norm(self, datamodel):
num = 10
s = point_space(num, datamodel=datamodel)
d = s.cast(np.arange(num))
assert_almost_equal(s.get_norm(d), 16.881943016134134)
assert_almost_equal(s.get_norm(d, q=3), 12.651489979526238)
############################################################################### ###############################################################################
@parameterized.expand( @parameterized.expand(
...@@ -599,6 +587,18 @@ class Test_Point_Space(unittest.TestCase): ...@@ -599,6 +587,18 @@ class Test_Point_Space(unittest.TestCase):
assert_equal(s.calc_dot(1, 1), num) assert_equal(s.calc_dot(1, 1), num)
assert_equal(s.calc_dot(np.arange(num), 1), num * (num - 1.) / 2.) assert_equal(s.calc_dot(np.arange(num), 1), num * (num - 1.) / 2.)
###############################################################################
@parameterized.expand(
itertools.product(DATAMODELS['point_space']),
testcase_func_name=custom_name_func)
def test_calc_norm(self, datamodel):
num = 10
s = point_space(num, datamodel=datamodel)
d = s.cast(np.arange(num))
assert_almost_equal(s.calc_norm(d), 16.881943016134134)
assert_almost_equal(s.calc_norm(d, q=3), 12.651489979526238)
############################################################################### ###############################################################################
@parameterized.expand( @parameterized.expand(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment