Commit 82aff6f4 authored by Ultima's avatar Ultima
Browse files

Added d2o_init_checks to global configuration.

parent 124e6763
......@@ -57,6 +57,11 @@ variable_default_distribution_strategy = variable(
if (z == 'pyfftw') else True)
)
variable_d2o_init_checks = variable('d2o_init_checks',
[True, False],
lambda z: isinstance(z, bool),
genus='boolean')
global_configuration = configuration(
[variable_fft_module,
variable_lm2gl,
......@@ -64,7 +69,8 @@ global_configuration = configuration(
variable_use_libsharp,
variable_verbosity,
variable_mpi_module,
variable_default_distribution_strategy
variable_default_distribution_strategy,
variable_d2o_init_checks
],
path=os.path.expanduser('~') + "/.nifty/global_config")
......
......@@ -1008,12 +1008,23 @@ class _distributor_factory(object):
return return_dict
return_dict = {}
# Check that all nodes got the same distribution_strategy
strat_list = comm.allgather(distribution_strategy)
if all(x == strat_list[0] for x in strat_list) == False:
expensive_checks = gc['d2o_init_checks']
# Parse the MPI communicator
if comm is None:
raise ValueError(about._errors.cstring(
"ERROR: The distribution-strategy must be the same on " +
"all nodes!"))
"ERROR: The distributor needs MPI-communicator object comm!"))
else:
return_dict['comm'] = comm
if expensive_checks:
# Check that all nodes got the same distribution_strategy
strat_list = comm.allgather(distribution_strategy)
if all(x == strat_list[0] for x in strat_list) == False:
raise ValueError(about._errors.cstring(
"ERROR: The distribution-strategy must be the same on " +
"all nodes!"))
# Check for an hdf5 file and open it if given
if 'h5py' in gdi and alias is not None:
......@@ -1029,12 +1040,6 @@ class _distributor_factory(object):
else:
dset = None
# Parse the MPI communicator
if comm is None:
raise ValueError(about._errors.cstring(
"ERROR: The distributor needs MPI-communicator object comm!"))
else:
return_dict['comm'] = comm
# Parse the datatype
if distribution_strategy in ['not', 'equal', 'fftw'] and \
......@@ -1069,10 +1074,11 @@ class _distributor_factory(object):
else:
dtype = np.dtype(dtype)
dtype_list = comm.allgather(dtype)
if all(x == dtype_list[0] for x in dtype_list) == False:
raise ValueError(about._errors.cstring(
"ERROR: The given dtype must be the same on all nodes!"))
if expensive_checks:
dtype_list = comm.allgather(dtype)
if all(x == dtype_list[0] for x in dtype_list) == False:
raise ValueError(about._errors.cstring(
"ERROR: The given dtype must be the same on all nodes!"))
return_dict['dtype'] = dtype
# Parse the shape
......@@ -1092,10 +1098,13 @@ class _distributor_factory(object):
raise ValueError(about._errors.cstring(
"ERROR: global_shape == () is not a valid shape!"))
global_shape_list = comm.allgather(global_shape)
if not all(x == global_shape_list[0] for x in global_shape_list):
raise ValueError(about._errors.cstring(
"ERROR: The global_shape must be the same on all nodes!"))
if expensive_checks:
global_shape_list = comm.allgather(global_shape)
if not all(x == global_shape_list[0]
for x in global_shape_list):
raise ValueError(about._errors.cstring(
"ERROR: The global_shape must be the same on all " +
"nodes!"))
return_dict['global_shape'] = global_shape
# Case 2: local-type slicer
......@@ -1114,16 +1123,17 @@ class _distributor_factory(object):
raise ValueError(about._errors.cstring(
"ERROR: local_shape == () is not a valid shape!"))
local_shape_list = comm.allgather(local_shape[1:])
cleared_set = set(local_shape_list)
cleared_set.discard(())
if len(cleared_set) > 1:
# if not any(x == () for x in map(np.shape, local_shape_list)):
# if not all(x == local_shape_list[0] for x in
# local_shape_list):
raise ValueError(about._errors.cstring(
"ERROR: All but the first entry of local_shape must be " +
"the same on all nodes!"))
if expensive_checks:
local_shape_list = comm.allgather(local_shape[1:])
cleared_set = set(local_shape_list)
cleared_set.discard(())
if len(cleared_set) > 1:
# if not any(x == () for x in map(np.shape, local_shape_list)):
# if not all(x == local_shape_list[0] for x in
# local_shape_list):
raise ValueError(about._errors.cstring(
"ERROR: All but the first entry of local_shape " +
"must be the same on all nodes!"))
return_dict['local_shape'] = local_shape
# Add the name of the distributor if needed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment