Commit d7853ed7 authored by theos

Fixed dependency on nifty.about.

parent af321e6c
Pipeline #3815 skipped
# -*- coding: utf-8 -*-
import numpy as np
-from nifty import about
def cast_axis_to_tuple(axis):
@@ -13,6 +12,6 @@ def cast_axis_to_tuple(axis):
if np.isscalar(axis):
axis = (int(axis), )
else:
-raise TypeError(about._errors.cstring(
-"ERROR: Could not convert axis-input to tuple of ints"))
+raise TypeError(
+"ERROR: Could not convert axis-input to tuple of ints")
return axis
@@ -4,7 +4,7 @@ import numpy as np
from d2o.config import configuration as gc,\
-dependency_injector as gdi
+dependency_injector as gdi
from d2o_librarian import d2o_librarian
from cast_axis_to_tuple import cast_axis_to_tuple
@@ -13,8 +13,11 @@ from strategies import STRATEGIES
MPI = gdi[gc['mpi_module']]
+about_cstring = lambda z: z
+from sys import stdout
-warn_print = lambda z: stdout.write(z + "\n"); stdout.flush()
+about_warnings_cprint = lambda z: stdout.write(z + "\n"); stdout.flush()
+about_infos_cprint = lambda z: stdout.write(z + "\n"); stdout.flush()
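A note on the two cprint shims added in this hunk: a lambda body ends at the semicolon, so the trailing stdout.flush() runs once at import time rather than on every call. A hedged sketch of an equivalent shim that actually flushes per call (plain functions instead of lambdas):

from sys import stdout

def about_warnings_cprint(z):
    # write the message and flush immediately on every call
    stdout.write(z + "\n")
    stdout.flush()

about_infos_cprint = about_warnings_cprint  # infos behave identically here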
class distributed_data_object(object):
"""A multidimensional array with modular MPI-based distribution schemes.
@@ -124,7 +127,7 @@ class distributed_data_object(object):
only one local object on a specific node is given.
In order to speed up the init process the distributor_factory checks
-if the global_configuration object gc yields gc['d2o_init_checks'] == True.
+if the global_configuration object gc yields gc['mpi_init_checks'] == True.
If yes, all expensive checks are skipped; namely those which need
mpi communication. Use this in order to get a fast init speed without
losing d2o's init parsing logic.
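A minimal sketch of toggling this switch, assuming (not confirmed by this diff) that the configuration object supports item assignment; the key name 'mpi_init_checks' is taken from the new code path above:

from d2o.config import configuration as gc

# hypothetical: disable the MPI-communication-based consistency checks
# to speed up distributed_data_object construction
gc['mpi_init_checks'] = False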
@@ -412,7 +415,7 @@ class distributed_data_object(object):
try:
result_data = function(local_data)
except:
-about.warnings.cprint(
+about_warnings_cprint(
"WARNING: Trying to use np.vectorize!")
result_data = np.vectorize(function)(local_data)
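The hunk above retries a failing elementwise call through numpy's vectorize wrapper. The same fallback pattern, self-contained and with generic names (not d2o's API):

import numpy as np

def apply_elementwise(function, local_data):
    try:
        # fast path: the function may handle whole arrays natively
        return function(local_data)
    except Exception:
        # slow path: lift a scalar-only function over every element
        return np.vectorize(function)(local_data)

# a scalar-only function: the direct array call raises, the fallback succeeds
print(apply_elementwise(lambda x: x if x > 0 else 0.0, np.array([-1.0, 2.0])))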
@@ -1301,7 +1304,7 @@ class distributed_data_object(object):
expensive.
"""
-about.warnings.cprint(
+about_warnings_cprint(
"WARNING: The current implementation of median is very expensive!")
median = np.median(self.get_full_data(), axis=axis, **kwargs)
if np.isscalar(median):
@@ -1437,7 +1440,7 @@ class distributed_data_object(object):
if self.dtype not in [np.dtype('int16'), np.dtype('int32'),
np.dtype('int64'), np.dtype('uint16'),
np.dtype('uint32'), np.dtype('uint64')]:
-raise TypeError(about._errors.cstring(
+raise TypeError(about_cstring(
"ERROR: Distributed-data-object must be of integer datatype!"))
minlength = max(self.amax() + 1, minlength)
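np.bincount itself only accepts integer input, which is what the dtype guard above enforces; minlength is then grown to amax() + 1 so the histogram covers every occurring value. A plain numpy sketch of that computation:

import numpy as np

data = np.array([1, 3, 3, 7], dtype=np.int64)
minlength = max(int(data.max()) + 1, 5)   # same rule as max(self.amax() + 1, minlength)
counts = np.bincount(data, minlength=minlength)
# counts has length 8 here; counts[3] == 2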
......
@@ -4,9 +4,8 @@ import numbers
import numpy as np
-from nifty.keepers import about,\
-global_configuration as gc,\
-global_dependency_injector as gdi
+from d2o.config import configuration as gc,\
+dependency_injector as gdi
from distributed_data_object import distributed_data_object
@@ -24,6 +23,11 @@ h5py = gdi.get('h5py')
pyfftw = gdi.get('pyfftw')
+about_cstring = lambda z: z
+from sys import stdout
+about_infos_cprint = lambda z: stdout.write(z + "\n"); stdout.flush()
class _distributor_factory(object):
def __init__(self):
@@ -48,11 +52,11 @@ class _distributor_factory(object):
return_dict = {}
-expensive_checks = gc['d2o_init_checks']
+expensive_checks = gc['mpi_init_checks']
# Parse the MPI communicator
if comm is None:
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: The distributor needs MPI-communicator object comm!"))
else:
return_dict['comm'] = comm
@@ -61,7 +65,7 @@ class _distributor_factory(object):
# Check that all nodes got the same distribution_strategy
strat_list = comm.allgather(distribution_strategy)
if all(x == strat_list[0] for x in strat_list) == False:
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: The distribution-strategy must be the same on " +
"all nodes!"))
@@ -88,7 +92,7 @@ class _distributor_factory(object):
if dtype is None:
if global_data is None:
dtype = np.dtype('float64')
-about.infos.cprint('INFO: dtype was set to default.')
+about_infos_cprint('INFO: dtype was set to default.')
else:
try:
dtype = global_data.dtype
@@ -108,14 +112,14 @@ class _distributor_factory(object):
dtype = np.array(local_data).dtype
else:
dtype = np.dtype('float64')
-about.infos.cprint('INFO: dtype was set to default.')
+about_infos_cprint('INFO: dtype was set to default.')
else:
dtype = np.dtype(dtype)
if expensive_checks:
dtype_list = comm.allgather(dtype)
if all(x == dtype_list[0] for x in dtype_list) == False:
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: The given dtype must be the same on all nodes!"))
return_dict['dtype'] = dtype
@@ -129,18 +133,18 @@ class _distributor_factory(object):
elif global_shape is not None:
global_shape = tuple(global_shape)
else:
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: Neither non-0-dimensional global_data nor " +
"global_shape nor hdf5 file supplied!"))
if global_shape == ():
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: global_shape == () is not a valid shape!"))
if expensive_checks:
global_shape_list = comm.allgather(global_shape)
if not all(x == global_shape_list[0]
for x in global_shape_list):
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: The global_shape must be the same on all " +
"nodes!"))
return_dict['global_shape'] = global_shape
@@ -154,11 +158,11 @@ class _distributor_factory(object):
elif local_shape is not None:
local_shape = tuple(local_shape)
else:
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: Neither non-0-dimensional local_data nor " +
"local_shape nor global d2o supplied!"))
if local_shape == ():
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: local_shape == () is not a valid shape!"))
if expensive_checks:
@@ -166,7 +170,7 @@ class _distributor_factory(object):
cleared_set = set(local_shape_list)
cleared_set.discard(())
if len(cleared_set) > 1:
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: All but the first entry of local_shape " +
"must be the same on all nodes!"))
return_dict['local_shape'] = local_shape
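The set-based check above tolerates nodes that hold no data (their entry is the empty tuple, which is discarded) and rejects any disagreement among the remaining entries. Standalone, with the gathered shapes mocked up (the allgather itself sits outside this hunk):

# e.g. the trailing dimensions gathered from every node
local_shape_list = [(3, 5), (3, 5), ()]
cleared_set = set(local_shape_list)
cleared_set.discard(())          # nodes without data are exempt
if len(cleared_set) > 1:
    raise ValueError("ERROR: All but the first entry of local_shape "
                     "must be the same on all nodes!")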
@@ -210,7 +214,7 @@ class _distributor_factory(object):
def get_distributor(self, distribution_strategy, comm, **kwargs):
# check if the distribution strategy is known
if distribution_strategy not in STRATEGIES['all']:
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: Unknown distribution strategy supplied."))
# parse the kwargs
@@ -273,7 +277,7 @@ def _infer_key_type(key):
found = 'd2o'
found_boolean = (key.dtype == np.bool_)
else:
raise ValueError(about._errors.cstring("ERROR: Unknown keytype!"))
raise ValueError(about_cstring("ERROR: Unknown keytype!"))
return (found, found_boolean)
@@ -395,7 +399,7 @@ class _slicing_distributor(distributor):
self._my_dtype_converter = dtype_converter
if not self._my_dtype_converter.known_np_Q(self.dtype):
-raise TypeError(about._errors.cstring(
+raise TypeError(about_cstring(
"ERROR: The datatype " + str(self.dtype.__repr__()) +
" is not known to mpi4py."))
@@ -457,7 +461,7 @@ class _slicing_distributor(distributor):
local_data = np.array(local_data).astype(
self.dtype, copy=copy).reshape(self.local_shape)
else:
-raise TypeError(about._errors.cstring(
+raise TypeError(about_cstring(
"ERROR: Unknown distribution strategy"))
return (local_data, hermitian)
@@ -467,7 +471,7 @@ class _slicing_distributor(distributor):
def globalize_index(self, index):
index = np.array(index, dtype=np.int).flatten()
if index.shape != (len(self.global_shape),):
raise TypeError(about._errors.cstring("ERROR: Length\
raise TypeError(about_cstring("ERROR: Length\
of index tuple does not match the array's shape!"))
globalized_index = index
globalized_index[0] = index[0] + self.local_start
@@ -477,7 +481,7 @@ class _slicing_distributor(distributor):
-np.array(self.global_shape),
np.array(self.global_shape) - 1)
if np.any(global_index_memory != globalized_index):
about.warnings.cprint("WARNING: Indices were clipped!")
about_infos_cprint("WARNING: Indices were clipped!")
globalized_index = tuple(globalized_index)
return globalized_index
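The clipping above bounds each index component to the valid range [-shape, shape - 1] per axis and merely warns when something was cut. As plain numpy:

import numpy as np

global_shape = (10, 4)
index = np.array([12, -7])
clipped = np.clip(index, -np.array(global_shape), np.array(global_shape) - 1)
# clipped == [9, -4]; any change relative to index triggers the warning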
@@ -680,7 +684,7 @@ class _slicing_distributor(distributor):
**kwargs)
else:
if from_key is not None:
-about.infos.cprint(
+about_infos_cprint(
"INFO: Advanced injection is not available for this " +
"combination of to_key and from_key.")
prepared_data_update = data_update[from_key]
@@ -699,7 +703,7 @@ class _slicing_distributor(distributor):
# Case 2.1: The array is boolean.
if to_found_boolean:
if from_key is not None:
-about.infos.cprint(
+about_infos_cprint(
"INFO: Advanced injection is not available for this " +
"combination of to_key and from_key.")
prepared_data_update = data_update[from_key]
@@ -715,7 +719,7 @@ class _slicing_distributor(distributor):
# advanced slicing is supported.
else:
if len(to_key.shape) != 1:
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"WARNING: Only one-dimensional advanced indexing " +
"is supported"))
# Make a recursive call in order to trigger the 'list'-section
@@ -728,7 +732,7 @@ class _slicing_distributor(distributor):
# one-dimensional advanced indexing list.
elif to_found == 'indexinglist':
if from_key is not None:
-about.infos.cprint(
+about_infos_cprint(
"INFO: Advanced injection is not available for this " +
"combination of to_key and from_key.")
prepared_data_update = data_update[from_key]
@@ -813,7 +817,7 @@ class _slicing_distributor(distributor):
if to_step is None:
to_step = 1
elif to_step == 0:
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: to_step size == 0!"))
# Compute the offset of the data the individual node will take.
@@ -851,7 +855,7 @@ class _slicing_distributor(distributor):
shifted_stop=data_update.shape[0],
global_length=data_update.shape[0])
if from_slices_start is None:
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: _backshift_and_decycle should never return " +
"None for local_start!"))
@@ -860,7 +864,7 @@ class _slicing_distributor(distributor):
if from_step is None:
from_step = 1
elif from_step == 0:
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: from_step size == 0!"))
localized_from_start = from_slices_start + from_step * o[r]
@@ -969,7 +973,7 @@ class _slicing_distributor(distributor):
# advanced slicing is supported.
else:
if len(key.shape) != 1:
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"WARNING: Only one-dimensional advanced indexing " +
"is supported"))
# Make a recursive call in order to trigger the 'list'-section
@@ -986,7 +990,7 @@ class _slicing_distributor(distributor):
def collect_data_from_list(self, data, list_key, copy=True, **kwargs):
if list_key == []:
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: key == [] is an unsupported key!"))
local_list_key = self._advanced_index_decycler(list_key)
local_result = data[local_list_key]
@@ -1012,7 +1016,7 @@ class _slicing_distributor(distributor):
# if the index is still negative, or it is greater than
# global_length, the index is ill-chosen
if zeroth_key < 0 or zeroth_key >= global_length:
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: Index out of bounds!"))
# shift the index
local_zeroth_key = zeroth_key - shift
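The scalar branch above is ordinary negative-index normalization followed by a bounds check and a shift into node-local coordinates; condensed into a standalone sketch:

global_length = 10   # total length of the distributed axis
shift = 4            # first global index stored on this node
zeroth_key = -3      # user-supplied global index

if zeroth_key < 0:
    zeroth_key += global_length          # -3 becomes 7
if zeroth_key < 0 or zeroth_key >= global_length:
    raise ValueError("ERROR: Index out of bounds!")
local_zeroth_key = zeroth_key - shift    # 7 - 4 = 3 on this node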
@@ -1037,7 +1041,7 @@ class _slicing_distributor(distributor):
# if there are still negative indices, or indices greater than
# global_length, the indices are ill-chosen
if (zeroth_key < 0).any() or (zeroth_key >= global_length).any():
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: Index out of bounds!"))
# shift the indices according to shift
shift_list = self.comm.allgather(shift)
@@ -1064,7 +1068,7 @@ class _slicing_distributor(distributor):
# TODO: Implement fast check!
# if not all(result[i] <= result[i + 1]
# for i in xrange(len(result) - 1)):
-# raise ValueError(about._errors.cstring(
+# raise ValueError(about_cstring(
# "ERROR: The first dimemnsion of list_key must be sorted!"))
result = [result]
@@ -1092,7 +1096,7 @@ class _slicing_distributor(distributor):
# if there are still negative indices, or indices greater than
# global_length, the indices are ill-chosen
if (zeroth_key < 0).any() or (zeroth_key >= global_length).any():
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: Index out of bounds!"))
# shift the indices according to shift
local_zeroth_key = zeroth_key - shift
@@ -1106,7 +1110,7 @@ class _slicing_distributor(distributor):
# TODO: Implement fast check!
# if not all(result[0][i] <= result[0][i + 1]
# for i in xrange(len(result[0]) - 1)):
-# raise ValueError(about._errors.cstring(
+# raise ValueError(about_cstring(
# "ERROR: The first dimemnsion of list_key must be sorted!"))
for ii in xrange(1, len(from_list_key)):
@@ -1307,7 +1311,7 @@ class _slicing_distributor(distributor):
if np.all(global_matchQ):
extracted_data = data_object[:]
else:
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: supplied shapes do neither match globally " +
"nor locally"))
@@ -1394,7 +1398,7 @@ class _slicing_distributor(distributor):
# check if the dimensions match
if len(self.global_shape) != len(foreign.shape):
raise ValueError(
about._errors.cstring("ERROR: unequal number of dimensions!"))
about_cstring("ERROR: unequal number of dimensions!"))
# check direct matches
direct_match = (np.array(self.global_shape) == np.array(foreign.shape))
# check broadcast compatibility
@@ -1404,7 +1408,7 @@ class _slicing_distributor(distributor):
combined_match = (direct_match | broadcast_match)
if not np.all(combined_match):
raise ValueError(
about._errors.cstring("ERROR: incompatible shapes!"))
about_cstring("ERROR: incompatible shapes!"))
matching_dimensions = tuple(direct_match)
return (foreign, matching_dimensions)
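An axis passes the combined test above if the two shapes match there directly or the foreign extent is 1 (the usual numpy broadcasting rule; the broadcast_match computation itself lies just outside this hunk). Sketched with plain arrays:

import numpy as np

global_shape = (6, 4)
foreign_shape = (6, 1)    # broadcastable along the second axis
direct_match = (np.array(global_shape) == np.array(foreign_shape))
broadcast_match = (np.array(foreign_shape) == 1)
combined_match = direct_match | broadcast_match
if not np.all(combined_match):
    raise ValueError("ERROR: incompatible shapes!")
matching_dimensions = tuple(direct_match)   # (True, False)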
@@ -1736,7 +1740,7 @@ class _slicing_distributor(distributor):
f[alias]
# if yes, and overwriteQ is set to False, raise an Error
if overwriteQ is False:
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: overwriteQ is False, but alias already " +
"in use!"))
else: # if yes, remove the existing dataset
@@ -1765,12 +1769,12 @@ class _slicing_distributor(distributor):
dset = f[alias]
# check shape
if dset.shape != self.global_shape:
-raise TypeError(about._errors.cstring(
+raise TypeError(about_cstring(
"ERROR: The shape of the given dataset does not match " +
"the distributed_data_object."))
# check dtype
if dset.dtype != self.dtype:
-raise TypeError(about._errors.cstring(
+raise TypeError(about_cstring(
"ERROR: The datatype of the given dataset does not " +
"match the one of the distributed_data_object."))
# if everything seems to fit, load the data
@@ -1780,11 +1784,11 @@ class _slicing_distributor(distributor):
return data
else:
def save_data(self, *args, **kwargs):
-raise ImportError(about._errors.cstring(
+raise ImportError(about_cstring(
"ERROR: h5py is not available"))
def load_data(self, *args, **kwargs):
-raise ImportError(about._errors.cstring(
+raise ImportError(about_cstring(
"ERROR: h5py is not available"))
def get_iter(self, d2o):
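Both load_data variants in this hunk guard an h5py read with shape and dtype comparisons before touching the payload. A minimal standalone version of that guarded load (path and alias are placeholders; assumes h5py is installed):

import h5py

def load_checked(path, alias, global_shape, dtype):
    with h5py.File(path, 'r') as f:
        dset = f[alias]
        if dset.shape != global_shape:
            raise TypeError("ERROR: The shape of the given dataset does not "
                            "match the distributed_data_object.")
        if dset.dtype != dtype:
            raise TypeError("ERROR: The datatype of the given dataset does "
                            "not match the one of the distributed_data_object.")
        return dset[...]   # everything fits: load the data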
@@ -1823,9 +1827,9 @@ def _freeform_slicer(comm, local_shape):
cleared_set.discard(())
if len(cleared_set) > 1:
raise ValueError(about._errors.cstring("ERROR: All but the first " +
"dimensions of local_shape " +
"must be the same!"))
raise ValueError(about_cstring("ERROR: All but the first " +
"dimensions of local_shape " +
"must be the same!"))
if local_shape == ():
first_shape_index = 0
else:
@@ -2043,7 +2047,7 @@ class _not_distributor(distributor):
f[alias]
# if yes, and overwriteQ is set to False, raise an Error
if overwriteQ is False:
-raise ValueError(about._errors.cstring(
+raise ValueError(about_cstring(
"ERROR: overwriteQ == False, but alias already " +
"in use!"))
else: # if yes, remove the existing dataset
@@ -2072,12 +2076,12 @@ class _not_distributor(distributor):
dset = f[alias]
# check shape
if dset.shape != self.global_shape:
-raise TypeError(about._errors.cstring(
+raise TypeError(about_cstring(
"ERROR: The shape of the given dataset does not match " +
"the distributed_data_object."))
# check dtype
if dset.dtype != self.dtype:
raise TypeError(about._errors.cstring(
raise TypeError(about_cstring(
"ERROR: The datatype of the given dataset does not " +
"match the distributed_data_object."))
# if everything seems to fit, load the data
......@@ -2087,11 +2091,11 @@ class _not_distributor(distributor):
return data
else:
def save_data(self, *args, **kwargs):
-raise ImportError(about._errors.cstring(
+raise ImportError(about_cstring(
"ERROR: h5py is not available"))
def load_data(self, *args, **kwargs):
-raise ImportError(about._errors.cstring(
+raise ImportError(about_cstring(
"ERROR: h5py is not available"))
def get_iter(self, d2o):
......
@@ -2,8 +2,8 @@
import numpy as np
-from nifty.keepers import global_configuration as gc,\
-global_dependency_injector as gdi
+from d2o.config import configuration as gc,\
+dependency_injector as gdi
MPI = gdi[gc['mpi_module']]
......
# -*- coding: utf-8 -*-
-from nifty.keepers import global_dependency_injector as gdi
+from d2o.config import dependency_injector as gdi
pyfftw = gdi.get('pyfftw')
......
@@ -2,8 +2,8 @@
import numpy as np
-from nifty.keepers import global_configuration as gc,\
-global_dependency_injector as gdi
+from d2o.config import configuration as gc,\
+dependency_injector as gdi
MPI = gdi[gc['mpi_module']]
......
@@ -13,9 +13,9 @@ import numpy as np
import warnings
import tempfile
-import nifty
-from nifty.d2o import distributed_data_object,\
-STRATEGIES
+import d2o
+from d2o import distributed_data_object,\
+STRATEGIES
from distutils.version import LooseVersion as lv
@@ -81,10 +81,10 @@ comparison_operators = ['__ne__', '__lt__', '__le__', '__eq__', '__ge__',
###############################################################################
hdf5_test_paths = [ # ('hdf5_init_test.hdf5', None),
-('hdf5_init_test.hdf5', os.path.join(os.path.dirname(nifty.__file__),
+('hdf5_init_test.hdf5', os.path.join(os.path.dirname(d2o.__file__),
'test/hdf5_init_test.hdf5')),
('hdf5_init_test.hdf5',
-os.path.join(os.path.dirname(nifty.__file__),
+os.path.join(os.path.dirname(d2o.__file__),
'test//hdf5_test_folder/hdf5_init_test.hdf5'))]
###############################################################################
......