Commit d7045735 authored by Ultima's avatar Ultima
Browse files

Made some code-cleanups to nifty_mpi_data.py.

parent 10e11a0a
......@@ -21,26 +21,23 @@
## along with this program. If not, see <http://www.gnu.org/licenses/>.
##initialize the 'FOUND-packages'-dictionary
FOUND = {}
import numpy as np
from nifty_about import about
from weakref import WeakValueDictionary as weakdict
# initialize the 'FOUND-packages'-dictionary
FOUND = {}
try:
from mpi4py import MPI
FOUND['MPI'] = True
except(ImportError):
except(ImportError):
import mpi_dummy as MPI
FOUND['MPI'] = False
try:
import pyfftw
FOUND['pyfftw'] = True
except(ImportError):
except(ImportError):
FOUND['pyfftw'] = False
try:
......@@ -59,6 +56,7 @@ HDF5_DISTRIBUTION_STRATEGIES = ['equal', 'fftw']
COMM = MPI.COMM_WORLD
class distributed_data_object(object):
"""
......@@ -67,21 +65,21 @@ class distributed_data_object(object):
Parameters
----------
global_data : {tuple, list, numpy.ndarray} *at least 1-dimensional*
Initial data which will be cast to a numpy.ndarray and then
stored according to the distribution strategy. The shape of
global_data overrides global_shape.
global_shape : tuple of ints, *optional*
If no global_data is supplied, global_shape can be used to
initialize an empty distributed_data_object
dtype : type, *optional*
If an explicit dtype is supplied, the given global_data will be
cast to it.
distribution_strategy : {'fftw' (default), 'not'}, *optional*
Specifies how global_data will be distributed to the
individual nodes.
'fftw' follows the distribution strategy of pyfftw.
'not' does not distribute the data at all.
'not' does not distribute the data at all.
Attributes
----------
......@@ -92,99 +90,34 @@ class distributed_data_object(object):
distribution_strategy : string
Name of the used distribution_strategy
distributor : distributor
The distributor object which takes care of all distribution and
consolidation of the data.
The distributor object which takes care of all distribution and
consolidation of the data.
shape : tuple of int
The global shape of the data
Raises
------
TypeError :
If the supplied distribution strategy is not known.
TypeError :
If the supplied distribution strategy is not known.
"""
def __init__(self, global_data = None, global_shape=None, dtype=None,
def __init__(self, global_data=None, global_shape=None, dtype=None,
local_data=None, local_shape=None,
distribution_strategy='fftw', hermitian=False,
alias=None, path=None, comm = MPI.COMM_WORLD,
copy = True, *args, **kwargs):
#
# ## a given hdf5 file overwrites the other parameters
# if FOUND['h5py'] == True and alias is not None:
# ## set file path
# file_path = path if (path is not None) else alias
# ## open hdf5 file
# if FOUND['h5py_parallel'] == True and FOUND['MPI'] == True:
# f = h5py.File(file_path, 'r', driver='mpio', comm=comm)
# else:
# f= h5py.File(file_path, 'r')
# ## open alias in file
# dset = f[alias]
# ## set shape
# global_shape = dset.shape
# ## set dtype
# dtype = dset.dtype.type
# ## if no hdf5 path was given, extract global_shape and dtype from
# ## the remaining arguments
# else:
# ## an explicitly given dtype overwrites the one from global_data
# if dtype is None:
# if global_data is None and local_data is None:
# raise ValueError(about._errors.cstring(
# "ERROR: Neither global_data nor local_data nor dtype supplied!"))
# elif global_data is not None:
# try:
# dtype = global_data.dtype.type
# except(AttributeError):
# try:
# dtype = global_data.dtype
# except(AttributeError):
# dtype = np.array(global_data).dtype.type
# elif local_data is not None:
# try:
# dtype = local_data.dtype.type
# except(AttributeError):
# try:
# dtype = local_data.dtype
# except(AttributeError):
# dtype = np.array(local_data).dtype.type
# else:
# dtype = np.dtype(dtype).type
#
# ## an explicitly given global_shape argument is only used if
# ## 1. no global_data was supplied, or
# ## 2. global_data is a scalar/list of dimension 0.
#
# if global_data is not None and np.isscalar(global_data) == False:
# global_shape = global_data.shape
# elif global_shape is not None:
# global_shape = tuple(global_shape)
#
# if local_data is not None
#
## if global_shape is None:
## if global_data is None or np.isscalar(global_data):
## raise ValueError(about._errors.cstring(
## "ERROR: Neither non-0-dimensional global_data nor global_shape supplied!"))
## global_shape = global_data.shape
## else:
## if global_data is None or np.isscalar(global_data):
## global_shape = tuple(global_shape)
## else:
## global_shape = global_data.shape
## TODO: allow init with empty shape
alias=None, path=None, comm=MPI.COMM_WORLD,
copy=True, *args, **kwargs):
# TODO: allow init with empty shape
if isinstance(global_data, tuple) or isinstance(global_data, list):
global_data = np.array(global_data, copy=False)
if isinstance(local_data, tuple) or isinstance(local_data, list):
local_data = np.array(local_data, copy=False)
self.distributor = distributor_factory.get_distributor(
distribution_strategy = distribution_strategy,
comm = comm,
global_data = global_data,
global_data = global_data,
global_shape = global_shape,
local_data = local_data,
local_shape = local_shape,
......@@ -192,14 +125,14 @@ class distributed_data_object(object):
path = path,
dtype = dtype,
**kwargs)
self.distribution_strategy = distribution_strategy
self.dtype = self.distributor.dtype
self.shape = self.distributor.global_shape
self.local_shape = self.distributor.local_shape
self.comm = self.distributor.comm
self.init_args = args
self.init_args = args
self.init_kwargs = kwargs
(self.data, self.hermitian) = self.distributor.initialize_data(
......@@ -210,36 +143,20 @@ class distributed_data_object(object):
hermitian = hermitian,
copy = copy)
self.index = d2o_librarian.register(self)
# ## If a hdf5 path was given, load the data
# if FOUND['h5py'] == True and alias is not None:
# self.load(alias = alias, path = path)
# ## close the file handle
# f.close()
#
# ## If the input data was a scalar, set the whole array to this value
# elif global_data is not None and np.isscalar(global_data):
# temp = np.empty(self.distributor.local_shape, dtype = self.dtype)
# temp.fill(global_data)
# self.set_local_data(temp)
# self.hermitian = True
# else:
# self.set_full_data(data=global_data, hermitian=hermitian,
# copy = copy, **kwargs)
#
def copy(self, dtype=None, distribution_strategy=None, **kwargs):
    """Return a copy of this object, optionally re-typed or re-distributed.

    Parameters
    ----------
    dtype : type, *optional*
        Forwarded to ``copy_empty`` when creating the target object.
    distribution_strategy : string, *optional*
        Forwarded to ``copy_empty``. If it is None or equal to this
        object's strategy, the local data is copied directly; otherwise
        the data is transferred via ``inject`` over the full slice.
    **kwargs
        Forwarded to ``copy_empty``.

    Returns
    -------
    The newly created object, carrying this object's hermitian flag.
    """
    duplicate = self.copy_empty(dtype=dtype,
                                distribution_strategy=distribution_strategy,
                                **kwargs)
    same_layout = (distribution_strategy is None or
                   distribution_strategy == self.distribution_strategy)
    if same_layout:
        # Identical distribution: a node-local copy is sufficient.
        duplicate.set_local_data(self.get_local_data(), copy=True)
    else:
        # Different distribution: let inject move the data across.
        duplicate.inject((slice(None),), self, (slice(None),))
    duplicate.hermitian = self.hermitian
    return duplicate
def copy_empty(self, global_shape=None, local_shape=None, dtype=None,
def copy_empty(self, global_shape=None, local_shape=None, dtype=None,
distribution_strategy=None, **kwargs):
if self.distribution_strategy == 'not' and \
distribution_strategy in LOCAL_DISTRIBUTION_STRATEGIES and \
......@@ -250,7 +167,7 @@ class distributed_data_object(object):
distribution_strategy = 'equal',
**kwargs)
return result.copy_empty(distribution_strategy = 'freeform')
if global_shape is None:
global_shape = self.shape
if local_shape is None:
......@@ -261,7 +178,7 @@ class distributed_data_object(object):
distribution_strategy = self.distribution_strategy
kwargs.update(self.init_kwargs)
temp_d2o = distributed_data_object(global_shape=global_shape,
local_shape = local_shape,
dtype = dtype,
......@@ -270,81 +187,81 @@ class distributed_data_object(object):
*self.init_args,
**kwargs)
return temp_d2o
def apply_scalar_function(self, function, inplace=False, dtype=None):
    """Apply ``function`` element-wise to the node-local data.

    Parameters
    ----------
    function : callable
        First tried on the whole local array; if that raises an
        ordinary exception, it is retried through ``np.vectorize``.
    inplace : bool, *optional*
        If True the result is written into this object, otherwise into
        an empty copy created with ``copy_empty(dtype=dtype)``.
    dtype : type, *optional*
        dtype of the result object. It cannot be honoured in place; a
        warning is printed if it differs from the current dtype.

    Returns
    -------
    The object holding the result (``self`` when ``inplace`` is True).
    Its hermitian flag is kept only for ``np.exp`` and ``np.log`` and
    is reset to False for every other function.
    """
    remember_hermitianQ = self.hermitian

    if inplace == True:
        temp = self
        if dtype is not None and self.dtype != np.dtype(dtype):
            about.warnings.cprint(
                "WARNING: Inplace dtype conversion is not possible!")
    else:
        temp = self.copy_empty(dtype=dtype)

    if np.prod(self.local_shape) != 0:
        try:
            temp.data[:] = function(self.data)
        except Exception:
            # Only ordinary errors trigger the element-wise fallback;
            # KeyboardInterrupt/SystemExit must propagate instead of
            # causing a second evaluation of ``function``.
            temp.data[:] = np.vectorize(function)(self.data)
    # else: nothing to do here. The value-empty array is also
    # geometrically empty.

    if function in (np.exp, np.log):
        temp.hermitian = remember_hermitianQ
    else:
        temp.hermitian = False
    return temp
def apply_generator(self, generator):
    """Fill the local data with ``generator(local_shape)``.

    ``generator`` is called with the distributor's local shape and its
    return value is stored via ``set_local_data``. The hermitian flag
    is reset to False afterwards.
    """
    local_shape = self.distributor.local_shape
    fresh_data = generator(local_shape)
    self.set_local_data(fresh_data)
    self.hermitian = False
def __str__(self):
    """Return the string form of the node-local data."""
    local = self.data
    return local.__str__()
def __repr__(self):
    """Return a tag line followed by the repr of the node-local data."""
    header = '<distributed_data_object>\n'
    body = self.data.__repr__()
    return header + body
def _compare_helper(self, other, op):
result = self.copy_empty(dtype = np.bool_)
## Case 1: 'other' is a scalar
## -> make point-wise comparison
# Case 1: 'other' is a scalar
# -> make point-wise comparison
if np.isscalar(other):
result.set_local_data(
getattr(self.get_local_data(copy = False), op)(other))
return result
return result
## Case 2: 'other' is a numpy array or a distributed_data_object
## -> extract the local data and make point-wise comparison
# Case 2: 'other' is a numpy array or a distributed_data_object
# -> extract the local data and make point-wise comparison
elif isinstance(other, np.ndarray) or\
isinstance(other, distributed_data_object):
temp_data = self.distributor.extract_local_data(other)
result.set_local_data(
getattr(self.get_local_data(copy=False), op)(temp_data))
return result
## Case 3: 'other' is None
# Case 3: 'other' is None
elif other is None:
return False
## Case 4: 'other' is something different
## -> make a numpy casting and make a recursive call
# Case 4: 'other' is something different
# -> make a numpy casting and make a recursive call
else:
temp_other = np.array(other)
return getattr(self, op)(temp_other)
def __ne__(self, other):
    """Implement ``self != other`` through ``_compare_helper``."""
    comparison = '__ne__'
    return self._compare_helper(other, comparison)
def __lt__(self, other):
    """Implement ``self < other`` through ``_compare_helper``."""
    comparison = '__lt__'
    return self._compare_helper(other, comparison)
def __le__(self, other):
    """Implement ``self <= other`` through ``_compare_helper``."""
    comparison = '__le__'
    return self._compare_helper(other, comparison)
......@@ -371,23 +288,23 @@ class distributed_data_object(object):
return False
else:
return True
def __pos__(self):
    """Implement ``+self``: an empty copy filled with a copy of the
    local data."""
    result = self.copy_empty()
    local = self.get_local_data()
    result.set_local_data(data=local, copy=True)
    return result
def __neg__(self):
    """Implement ``-self``: an empty copy filled with the negated
    local data."""
    result = self.copy_empty()
    negated = self.get_local_data().__neg__()
    result.set_local_data(data=negated, copy=True)
    return result
def __abs__(self):
## translate complex dtypes
# translate complex dtypes
if self.dtype == np.dtype('complex64'):
new_dtype = np.dtype('float32')
elif self.dtype == np.dtype('complex128'):
......@@ -398,42 +315,39 @@ class distributed_data_object(object):
new_dtype = self.dtype
temp_d2o = self.copy_empty(dtype = new_dtype)
temp_d2o.set_local_data(data = self.get_local_data().__abs__(),
copy = True)
copy = True)
return temp_d2o
def _builtin_helper(self, operator, other, inplace=False):
if isinstance(other, distributed_data_object):
other_is_real = other.isreal()
else:
other_is_real = np.isreal(other)
## Case 1: other is not a scalar
# Case 1: other is not a scalar
if not (np.isscalar(other) or np.shape(other) == (1,)):
## if self.shape != other.shape:
## raise AttributeError(about._errors.cstring(
## "ERROR: Shapes do not match!"))
try:
try:
hermitian_Q = (other.hermitian and self.hermitian)
except(AttributeError):
hermitian_Q = False
## extract the local data from the 'other' object
# extract the local data from the 'other' object
temp_data = self.distributor.extract_local_data(other)
temp_data = operator(temp_data)
## Case 2: other is a real scalar -> preserve hermitianity
# Case 2: other is a real scalar -> preserve hermitianity
elif other_is_real or (self.dtype not in (np.dtype('complex128'),
np.dtype('complex256'))):
hermitian_Q = self.hermitian
temp_data = operator(other)
## Case 3: other is complex
# Case 3: other is complex
else:
hermitian_Q = False
temp_data = operator(other)
## write the new data into a new distributed_data_object
temp_data = operator(other)
# write the new data into a new distributed_data_object
if inplace == True:
temp_d2o = self
else:
## use common datatype for self and other
# use common datatype for self and other
new_dtype = np.dtype(np.find_common_type((self.dtype,),
(temp_data.dtype,)))
temp_d2o = self.copy_empty(
......@@ -441,24 +355,7 @@ class distributed_data_object(object):
temp_d2o.set_local_data(data=temp_data)
temp_d2o.hermitian = hermitian_Q
return temp_d2o
"""
def __inplace_builtin_helper__(self, operator, other):
## Case 1: other is not a scalar
if not (np.isscalar(other) or np.shape(other) == (1,)):
temp_data = self.distributor.extract_local_data(other)
temp_data = operator(temp_data)
## Case 2: other is a real scalar -> preserve hermitianity
elif np.isreal(other):
hermitian_Q = self.hermitian
temp_data = operator(other)
## Case 3: other is complex
else:
temp_data = operator(other)
self.set_local_data(data=temp_data)
self.hermitian = hermitian_Q
return self
"""
def __add__(self, other):
    """Implement ``self + other`` by handing the local array's bound
    ``__add__`` to ``_builtin_helper``."""
    local_op = self.get_local_data().__add__
    return self._builtin_helper(local_op, other)
......@@ -466,124 +363,102 @@ class distributed_data_object(object):
return self._builtin_helper(self.get_local_data().__radd__, other)
def __iadd__(self, other):
    """Implement ``self += other`` by handing the local array's bound
    ``__iadd__`` to ``_builtin_helper`` in in-place mode."""
    local_op = self.get_local_data().__iadd__
    return self._builtin_helper(local_op, other, inplace=True)
def __sub__(self, other):
    """Implement ``self - other`` by handing the local array's bound
    ``__sub__`` to ``_builtin_helper``."""
    local_op = self.get_local_data().__sub__
    return self._builtin_helper(local_op, other)
def __rsub__(self, other):
    """Implement ``other - self`` by handing the local array's bound
    ``__rsub__`` to ``_builtin_helper``."""
    local_op = self.get_local_data().__rsub__
    return self._builtin_helper(local_op, other)
def __isub__(self, other):
    """Implement ``self -= other`` by handing the local array's bound
    ``__isub__`` to ``_builtin_helper`` in in-place mode."""
    local_op = self.get_local_data().__isub__
    return self._builtin_helper(local_op, other, inplace=True)
def __div__(self, other):
    """Implement the classic-division hook ``self / other`` by handing
    the local array's bound ``__div__`` to ``_builtin_helper``."""
    local_op = self.get_local_data().__div__
    return self._builtin_helper(local_op, other)
def __truediv__(self, other):
    """Implement true division ``self / other``.

    The local array's own ``__truediv__`` is handed to
    ``_builtin_helper``. Delegating to ``__div__`` instead would apply
    classic (flooring) division to integer data and would fail on
    interpreters where numpy arrays have no ``__div__`` method.
    """
    local_op = self.get_local_data().__truediv__
    return self._builtin_helper(local_op, other)
def __rdiv__(self, other):
    """Implement the classic-division hook ``other / self`` by handing
    the local array's bound ``__rdiv__`` to ``_builtin_helper``."""
    local_op = self.get_local_data().__rdiv__
    return self._builtin_helper(local_op, other)
def __rtruediv__(self, other):
    """Implement reflected true division ``other / self``.

    The local array's own ``__rtruediv__`` is handed to
    ``_builtin_helper``. Delegating to ``__rdiv__`` instead would apply
    classic (flooring) division to integer data and would fail on
    interpreters where numpy arrays have no ``__rdiv__`` method.
    """
    local_op = self.get_local_data().__rtruediv__
    return self._builtin_helper(local_op, other)
def __idiv__(self, other):
    """Implement the classic-division hook ``self /= other`` by handing
    the local array's bound ``__idiv__`` to ``_builtin_helper`` in
    in-place mode."""
    local_op = self.get_local_data().__idiv__
    return self._builtin_helper(local_op, other, inplace=True)
def __itruediv__(self, other):
    """Implement in-place true division ``self /= other``.

    The local array's own ``__itruediv__`` is handed to
    ``_builtin_helper`` in in-place mode. Delegating to ``__idiv__``
    instead would apply classic (flooring) division to integer data and
    would fail on interpreters where numpy arrays have no ``__idiv__``
    method.

    NOTE(review): for integer-typed data numpy itself refuses in-place
    true division (the float result cannot be cast back); that error is
    propagated unchanged.
    """
    local_op = self.get_local_data().__itruediv__
    return self._builtin_helper(local_op, other, inplace=True)
def __floordiv__(self, other):
    """Implement ``self // other`` by handing the local array's bound
    ``__floordiv__`` to ``_builtin_helper``."""
    local_op = self.get_local_data().__floordiv__
    return self._builtin_helper(local_op, other)
def __rfloordiv__(self, other):
    """Implement ``other // self`` by handing the local array's bound
    ``__rfloordiv__`` to ``_builtin_helper``."""
    local_op = self.get_local_data().__rfloordiv__
    return self._builtin_helper(local_op, other)
def __ifloordiv__(self, other):
    """Implement ``self //= other`` by handing the local array's bound
    ``__ifloordiv__`` to ``_builtin_helper`` in in-place mode."""
    local_op = self.get_local_data().__ifloordiv__
    return self._builtin_helper(local_op, other, inplace=True)
def __mul__(self, other):
    """Implement ``self * other`` by handing the local array's bound
    ``__mul__`` to ``_builtin_helper``."""
    local_op = self.get_local_data().__mul__
    return self._builtin_helper(local_op, other)
def __rmul__(self, other):
    """Implement ``other * self`` by handing the local array's bound
    ``__rmul__`` to ``_builtin_helper``."""
    local_op = self.get_local_data().__rmul__
    return self._builtin_helper(local_op, other)
def __imul__(self, other):
    """Implement ``self *= other`` by handing the local array's bound
    ``__imul__`` to ``_builtin_helper`` in in-place mode."""
    local_op = self.get_local_data().__imul__
    return self._builtin_helper(local_op, other, inplace=True)
def __pow__(self, other):
    """Implement ``self ** other`` by handing the local array's bound
    ``__pow__`` to ``_builtin_helper``."""
    local_op = self.get_local_data().__pow__
    return self._builtin_helper(local_op, other)
def __rpow__(self, other):
    """Implement ``other ** self`` by handing the local array's bound
    ``__rpow__`` to ``_builtin_helper``."""
    local_op = self.get_local_data().__rpow__
    return self._builtin_helper(local_op, other)
def __ipow__(self, other):
    """Implement ``self **= other`` by handing the local array's bound
    ``__ipow__`` to ``_builtin_helper`` in in-place mode."""
    local_op = self.get_local_data().__ipow__
    return self._builtin_helper(local_op, other, inplace=True)
def __mod__(self, other):
    """Implement ``self % other`` by handing the local array's bound
    ``__mod__`` to ``_builtin_helper``."""
    local_op = self.get_local_data().__mod__
    return self._builtin_helper(local_op, other)
def __rmod__(self, other):
    """Implement ``other % self`` by handing the local array's bound
    ``__rmod__`` to ``_builtin_helper``."""
    local_op = self.get_local_data().__rmod__
    return self._builtin_helper(local_op, other)
def __imod__(self, other):
    """Implement ``self %= other`` by handing the local array's bound
    ``__imod__`` to ``_builtin_helper`` in in-place mode."""
    local_op = self.get_local_data().__imod__
    return self._builtin_helper(local_op, other, inplace=True)
def __len__(self):
    """Return the extent of the first axis of the global shape."""
    first_axis = self.shape[0]
    return first_axis
def get_dim(self):
    """Return the total number of elements, i.e. the product of all
    entries of the global shape."""
    global_shape = self.shape
    return np.prod(global_shape)
def vdot(self, other):
    """Return the global ``np.vdot`` of this object with ``other``.

    The locally relevant part of ``other`` is extracted by the
    distributor, ``np.vdot`` is evaluated on the local data, and the
    values returned by the distributor's ``_allgather`` are summed.
    """
    other_local = self.distributor.extract_local_data(other)
    partial = np.vdot(self.get_local_data(), other_local)
    # presumably one partial result per node -- confirm in distributor
    gathered = self.distributor._allgather(partial)
    return np.sum(gathered)
def __getitem__(self, key):
    """Implement ``self[key]`` by delegating to ``get_data``."""
    selection = self.get_data(key)
    return selection
# ## Case 1: key is a boolean array.
# ## -> take the local data portion from key, use this for data
# ## extraction, and then merge the result in a flat numpy array
# if isinstance(key, np.ndarray):
# found = 'ndarray'
# found_boolean = (key.dtype.type == np.bool_)
# elif isinstance(key, distributed_data_object):
# found = 'd2o'
# found_boolean = (key.dtype == np.bool_)
# else:
# found = 'other'
# ## TODO: transfer this into distributor:
# if (found == 'ndarray' or found == 'd2o') and found_boolean == True:
# ## extract the data of local relevance
# local_bool_array = self.distributor.extract_local_data(key)
# local_results = self.get_local_data(copy=False)[local_bool_array]
# global_results = self.distributor._allgather(local_results)
# global_results = np.concatenate(global_results)
# return global_results
#
# else:
# return self.get_data(key)
def __setitem__(self, key, data):
    """Implement ``self[key] = data`` by delegating to ``set_data``.

    Note the swapped argument order: ``set_data`` is called with the
    data first and the key second.
    """
    payload, target = data, key
    self.set_data(payload, target)
def _contraction_helper(self, function, **kwargs):
if np.prod(self.data.shape) == 0:
local = 0
......@@ -599,55 +474,55 @@ class distributed_data_object(object):
if work_list.shape[0] == 0:
raise ValueError("ERROR: Zero-size array to reduction op