Commit 124e6763 authored by Ultima

Several performance improvements.

parent ca85ae75
......@@ -148,10 +148,3 @@ class _comm_wrapper(object):
COMM_WORLD = _comm_wrapper('MPI_dummy_COMM_WORLD')
......@@ -154,6 +154,9 @@ class lm_space(point_space):
raise ImportError(about._errors.cstring(
"ERROR: neither libsharp_wrapper_gl nor healpy activated."))
self._cache_dict = {'check_codomain': {}}
self.paradict = lm_space_paradict(lmax=lmax, mmax=mmax)
# check data type
......@@ -196,7 +199,7 @@ class lm_space(point_space):
def __hash__(self):
    """
        Combines the hashes of the instance attributes into one integer.

        Every entry of ``vars(self)`` contributes
        ``item.__hash__() * hash(key)``; the contributions are XOR-ed
        together, so the result does not depend on attribute order.

        Returns
        -------
        result_hash : int
            The combined hash value.
    """
    result_hash = 0
    for (key, item) in vars(self).items():
        # `_cache_dict` is a plain dict (unhashable) that only memoizes
        # results, so it must not take part in hashing; `power_indices`
        # is skipped as well. Only this one skip list is kept -- the
        # older test on ['power_indices'] alone is superseded by it.
        if key in ('_cache_dict', 'power_indices'):
            continue
        result_hash ^= item.__hash__() * hash(key)
    return result_hash
......@@ -207,7 +210,7 @@ class lm_space(point_space):
((lambda x: tuple(x) if
isinstance(x, np.ndarray) else x)(ii[1])))
for ii in vars(self).iteritems()
if ii[0] not in ['power_indices', 'comm']]
if ii[0] not in ['_cache_dict', 'power_indices', 'comm']]
temp.append(('comm', self.comm.__hash__()))
# Return the sorted identifiers as a tuple.
return tuple(sorted(temp))
......@@ -303,7 +306,7 @@ class lm_space(point_space):
size=size,
kindex=kindex)
def check_codomain(self, codomain):
def _check_codomain(self, codomain):
"""
Checks whether a given codomain is compatible to the
:py:class:`lm_space` or not.
......@@ -938,6 +941,7 @@ class gl_space(point_space):
raise ImportError(about._errors.cstring(
"ERROR: libsharp_wrapper_gl not loaded."))
self._cache_dict = {'check_codomain': {}}
self.paradict = gl_space_paradict(nlat=nlat, nlon=nlon)
# check data type
......@@ -1040,7 +1044,7 @@ class gl_space(point_space):
size=size,
kindex=kindex)
def check_codomain(self, codomain):
def _check_codomain(self, codomain):
"""
Checks whether a given codomain is compatible to the space or not.
......@@ -1577,6 +1581,7 @@ class hp_space(point_space):
raise ImportError(about._errors.cstring(
"ERROR: healpy not available."))
self._cache_dict = {'check_codomain': {}}
# check parameters
self.paradict = hp_space_paradict(nside=nside)
......@@ -1668,7 +1673,7 @@ class hp_space(point_space):
size=size,
kindex=kindex)
def check_codomain(self, codomain):
def _check_codomain(self, codomain):
"""
Checks whether a given codomain is compatible to the space or not.
......
......@@ -801,6 +801,7 @@ class point_space(space):
-------
None.
"""
self._cache_dict = {'check_codomain': {}}
self.paradict = point_space_paradict(num=num)
# parse dtype
......@@ -842,6 +843,8 @@ class point_space(space):
# Extract the identifying parts from the vars(self) dict.
result_hash = 0
for (key, item) in vars(self).items():
if key in ['_cache_dict']:
continue
result_hash ^= item.__hash__() * hash(key)
return result_hash
......@@ -850,6 +853,7 @@ class point_space(space):
temp = [(ii[0],
((lambda x: x[1].__hash__() if x[0] == 'comm' else x)(ii)))
for ii in vars(self).iteritems()
if ii[0] not in ['_cache_dict']
]
# Return the sorted identifiers as a tuple.
return tuple(sorted(temp))
......@@ -1369,6 +1373,16 @@ class point_space(space):
return spec
def check_codomain(self, codomain):
    """
        Memoizing front-end for :py:meth:`_check_codomain`.

        The verdict for a codomain is stored in
        ``self._cache_dict['check_codomain']`` under ``id(codomain)``,
        together with a weak reference to the codomain. A cached verdict
        is only returned if that weak reference still resolves to the very
        object passed in; otherwise the check is recomputed.

        Parameters
        ----------
        codomain : object
            The candidate codomain; handed unchanged to
            ``self._check_codomain``.

        Returns
        -------
        check : object
            Whatever ``self._check_codomain(codomain)`` returns.
    """
    import weakref

    check_dict = self._cache_dict['check_codomain']
    temp_id = id(codomain)

    cached = check_dict.get(temp_id)
    if cached is not None:
        cached_ref, cached_result = cached
        # An id is only unique among objects that are alive at the same
        # time. Without this identity test a new codomain that happens
        # to reuse the id of a collected one would inherit its verdict.
        if cached_ref() is codomain:
            return cached_result

    temp_result = self._check_codomain(codomain)
    try:
        check_dict[temp_id] = (weakref.ref(codomain), temp_result)
    except TypeError:
        # Not weak-referenceable (e.g. None): skip caching and make sure
        # no stale entry stays behind under this id.
        check_dict.pop(temp_id, None)
    # NOTE(review): weak references cannot be pickled -- confirm that
    # `_cache_dict` is excluded wherever spaces get serialized.
    return temp_result
def _check_codomain(self, codomain):
"""
Checks whether a given codomain is compatible to the space or not.
......
......@@ -41,6 +41,10 @@ STRATEGIES = {
'not': ['not'],
'hdf5': ['equal'] + _maybe_fftw,
}
# Default distribution strategy for new distributed_data_objects:
# 'fftw' whenever `_maybe_fftw` is a non-empty list, otherwise 'equal'.
_default_strategy = 'fftw' if _maybe_fftw != [] else 'equal'
class distributed_data_object(object):
......@@ -89,7 +93,7 @@ class distributed_data_object(object):
"""
def __init__(self, global_data=None, global_shape=None, dtype=None,
local_data=None, local_shape=None,
distribution_strategy='fftw', hermitian=False,
distribution_strategy=_default_strategy, hermitian=False,
alias=None, path=None, comm=MPI.COMM_WORLD,
copy=True, *args, **kwargs):
......@@ -737,8 +741,10 @@ class distributed_data_object(object):
if self.distribution_strategy == 'not':
return local_counts
else:
list_of_counts = self.distributor._allgather(local_counts)
counts = np.sum(list_of_counts, axis=0)
counts = np.empty_like(local_counts)
self.distributor._Allreduce_sum(local_counts, counts)
# list_of_counts = self.distributor._allgather(local_counts)
# counts = np.sum(list_of_counts, axis=0)
return counts
def where(self):
......@@ -1463,6 +1469,14 @@ class _slicing_distributor(distributor):
gathered_things = comm.allgather(thing)
return gathered_things
def _Allreduce_sum(self, sendbuf, recvbuf):
    """
        Sums `sendbuf` over `self.comm` and writes the result to `recvbuf`.

        Both buffers are handed to ``self.comm.Allreduce`` as
        ``[buffer, mpi_dtype]`` pairs, where the MPI dtype is obtained
        from ``self._my_dtype_converter.to_mpi`` applied to the buffer's
        ``dtype``; the reduction operation is ``MPI.SUM``.

        Parameters
        ----------
        sendbuf : object with a `dtype` attribute
            The local contribution.
        recvbuf : object with a `dtype` attribute
            The buffer that receives the reduced data.

        Returns
        -------
        recvbuf : object
            The very buffer that was passed in as `recvbuf`.
    """
    to_mpi = self._my_dtype_converter.to_mpi
    send_spec = [sendbuf, to_mpi(sendbuf.dtype)]
    recv_spec = [recvbuf, to_mpi(recvbuf.dtype)]
    self.comm.Allreduce(send_spec, recv_spec, op=MPI.SUM)
    return recvbuf
def distribute_data(self, data=None, alias=None,
path=None, copy=True, **kwargs):
'''
......@@ -2137,13 +2151,20 @@ class _slicing_distributor(distributor):
# Case 2: First dimension fits directly and data_object is a d2o
elif isinstance(data_object, distributed_data_object):
# Check if the distributor and the comm match
# the own ones. Checking equality via 'is' is ok, as the
# distributor factory caches similar distributors
if self is data_object.distributor and\
self.comm is data_object.distributor.comm:
# Case 1: yes. Simply take the local data
# Check if both d2os have the same slicing
# If the distributor is exactly the same, extract the data
if self is data_object.distributor:
# Simply take the local data
extracted_data = data_object.data
# If the distributor is not exactly the same, check if the
# geometry matches if it is a slicing distributor
# -> comm and local shapes
elif isinstance(data_object.distributor, _slicing_distributor):
if (self.comm is data_object.distributor.comm) and \
np.all(self.all_local_slices ==
data_object.distributor.all_local_slices):
extracted_data = data_object.data
else:
# Case 2: no. All nodes extract their local slice from the
# data_object
......@@ -2153,6 +2174,26 @@ class _slicing_distributor(distributor):
local_keys=True)
extracted_data = extracted_data.get_local_data()
# # Check if the distributor and the comm match
# # the own ones. Checking equality via 'is' is ok, as the
# # distributor factory caches similar distributors
# if self is data_object.distributor and\
# self.comm is data_object.distributor.comm:
# # Case 1: yes. Simply take the local data
# extracted_data = data_object.data
# # If the distributors do not match directly, check
# else:
# # Case 2: no. All nodes extract their local slice from the
# # data_object
# extracted_data =\
# data_object.get_data(slice(self.local_start,
# self.local_end),
# local_keys=True)
# extracted_data = extracted_data.get_local_data()
#
## print ('boo', data_object.distribution_strategy)
# Case 3: First dimension fits directly and data_object is an
# generic array
else:
......@@ -2604,6 +2645,10 @@ class _not_distributor(distributor):
def _allgather(self, thing):
    """
        Returns `thing` wrapped in a one-element list.

        No communication takes place here; the list form matches what
        callers that iterate over gathered per-node results expect.
    """
    gathered_things = [thing]
    return gathered_things
def _Allreduce_sum(self, sendbuf, recvbuf):
    """
        Copies `sendbuf` into `recvbuf` in place and returns `recvbuf`.

        No reduction over a communicator is performed here: the content
        of `sendbuf` is simply assigned to the full slice of `recvbuf`.

        Parameters
        ----------
        sendbuf : sequence or array
            The data to be copied.
        recvbuf : mutable sequence or array
            The buffer that is overwritten via full-slice assignment.

        Returns
        -------
        recvbuf : object
            The very buffer that was passed in as `recvbuf`.
    """
    whole = slice(None)
    recvbuf[whole] = sendbuf
    return recvbuf
def distribute_data(self, data, alias=None, path=None, copy=True,
**kwargs):
if 'h5py' in gdi and alias is not None:
......
......@@ -153,7 +153,7 @@ class rg_space(point_space):
-------
None
"""
self._cache_dict = {'check_codomain':{}}
self.paradict = rg_space_paradict(shape=shape,
complexity=complexity,
zerocenter=zerocenter)
......@@ -229,7 +229,7 @@ class rg_space(point_space):
def __hash__(self):
    """
        Combines the hashes of the instance attributes into one integer.

        Every entry of ``vars(self)`` contributes
        ``item.__hash__() * hash(key)``; the contributions are XOR-ed
        together, so the result does not depend on attribute order.

        Returns
        -------
        result_hash : int
            The combined hash value.
    """
    result_hash = 0
    for (key, item) in vars(self).items():
        # `_cache_dict` is a plain dict (unhashable) that only memoizes
        # results, so it must not take part in hashing; `fft_machine`
        # and `power_indices` are skipped as well. Only this one skip
        # list is kept -- the older test without '_cache_dict' is
        # superseded by it.
        if key in ('_cache_dict', 'fft_machine', 'power_indices'):
            continue
        result_hash ^= item.__hash__() * hash(key)
    return result_hash
......@@ -245,7 +245,8 @@ class rg_space(point_space):
((lambda x: tuple(x) if
isinstance(x, np.ndarray) else x)(ii[1])))
for ii in vars(self).iteritems()
if ii[0] not in ['fft_machine', 'power_indices', 'comm']]
if ii[0] not in ['_cache_dict', 'fft_machine',
'power_indices', 'comm']]
temp.append(('comm', self.comm.__hash__()))
# Return the sorted identifiers as a tuple.
return tuple(sorted(temp))
......@@ -353,7 +354,7 @@ class rg_space(point_space):
size=size,
kindex=kindex)
def check_codomain(self, codomain):
def _check_codomain(self, codomain):
"""
Checks whether a given codomain is compatible to the space or not.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment