Commit bd1f7bc4 authored by csongor's avatar csongor

WIP: remove comm from spaces.

parent bc3f10ca
Pipeline #3857 skipped
......@@ -120,8 +120,7 @@ class lm_space(point_space):
Pixel volume of the :py:class:`lm_space`, which is always 1.
"""
def __init__(self, lmax, mmax=None, dtype=np.dtype('complex128'),
comm=gc['default_comm']):
def __init__(self, lmax, mmax=None, dtype=np.dtype('complex128')):
"""
Sets the attributes for an lm_space class instance.
......@@ -169,12 +168,10 @@ class lm_space(point_space):
self.discrete = True
self.harmonic = True
self.distances = (np.float(1),)
self.comm = self._parse_comm(comm)
self.power_indices = lm_power_indices(
lmax=self.paradict['lmax'],
dim=self.get_dim(),
comm=self.comm,
allowed_distribution_strategies=LM_DISTRIBUTION_STRATEGIES)
@property
......@@ -202,8 +199,7 @@ class lm_space(point_space):
((lambda x: tuple(x) if
isinstance(x, np.ndarray) else x)(ii[1])))
for ii in vars(self).iteritems()
if ii[0] not in ['_cache_dict', 'power_indices', 'comm']]
temp.append(('comm', self.comm.__hash__()))
if ii[0] not in ['_cache_dict', 'power_indices']]
# Return the sorted identifiers as a tuple.
return tuple(sorted(temp))
......@@ -323,9 +319,6 @@ class lm_space(point_space):
raise TypeError(about._errors.cstring(
"ERROR: The given codomain must be a nifty lm_space."))
if self.comm is not codomain.comm:
return False
elif isinstance(codomain, gl_space):
# lmax==mmax
# nlat==lmax+1
......@@ -385,13 +378,11 @@ class lm_space(point_space):
raise NotImplementedError
nlat = self.paradict['lmax'] + 1
nlon = self.paradict['lmax'] * 2 + 1
return gl_space(nlat=nlat, nlon=nlon, dtype=new_dtype,
comm=self.comm)
return gl_space(nlat=nlat, nlon=nlon, dtype=new_dtype)
elif coname == 'hp' or (coname is None and not gc['lm2gl']):
nside = (self.paradict['lmax']+1) // 3
return hp_space(nside=nside,
comm=self.comm)
return hp_space(nside=nside)
else:
raise ValueError(about._errors.cstring(
......@@ -947,8 +938,7 @@ class gl_space(point_space):
An array containing the pixel sizes.
"""
def __init__(self, nlat, nlon=None, dtype=np.dtype('float64'),
comm=gc['default_comm']):
def __init__(self, nlat, nlon=None, dtype=np.dtype('float64')):
"""
Sets the attributes for a gl_space class instance.
......@@ -993,7 +983,6 @@ class gl_space(point_space):
self.distances = tuple(gl.vol(self.paradict['nlat'],
nlon=self.paradict['nlon']
).astype(np.float))
self.comm = self._parse_comm(comm)
@property
def para(self):
......@@ -1099,9 +1088,6 @@ class gl_space(point_space):
if not isinstance(codomain, space):
raise TypeError(about._errors.cstring("ERROR: invalid input."))
if self.comm is not codomain.comm:
return False
if isinstance(codomain, lm_space):
nlat = self.paradict['nlat']
nlon = self.paradict['nlon']
......@@ -1130,11 +1116,9 @@ class gl_space(point_space):
mmax = nlat-1
# lmax,mmax = nlat-1,nlat-1
if self.dtype == np.dtype('float32'):
return lm_space(lmax=lmax, mmax=mmax, dtype=np.complex64,
comm=self.comm)
return lm_space(lmax=lmax, mmax=mmax, dtype=np.complex64)
else:
return lm_space(lmax=lmax, mmax=mmax, dtype=np.complex128,
comm=self.comm)
return lm_space(lmax=lmax, mmax=mmax, dtype=np.complex128)
def get_random_values(self, **kwargs):
"""
......@@ -1626,7 +1610,7 @@ class hp_space(point_space):
An array with one element containing the pixel size.
"""
def __init__(self, nside, comm=gc['default_comm']):
def __init__(self, nside):
"""
Sets the attributes for a hp_space class instance.
......@@ -1662,7 +1646,6 @@ class hp_space(point_space):
self.discrete = False
self.harmonic = False
self.distances = (np.float(4*np.pi / (12*self.paradict['nside']**2)),)
self.comm = self._parse_comm(comm)
@property
def para(self):
......@@ -1763,9 +1746,6 @@ class hp_space(point_space):
if not isinstance(codomain, space):
raise TypeError(about._errors.cstring("ERROR: invalid input."))
if self.comm is not codomain.comm:
return False
if isinstance(codomain, lm_space):
nside = self.paradict['nside']
lmax = codomain.paradict['lmax']
......@@ -1789,8 +1769,7 @@ class hp_space(point_space):
"""
lmax = 3*self.paradict['nside'] - 1
mmax = lmax
return lm_space(lmax=lmax, mmax=mmax, dtype=np.dtype('complex128'),
comm=self.comm)
return lm_space(lmax=lmax, mmax=mmax, dtype=np.dtype('complex128'))
def get_random_values(self, **kwargs):
"""
......
......@@ -785,8 +785,7 @@ class point_space(space):
Pixel volume of the :py:class:`point_space`, which is always 1.
"""
def __init__(self, num, dtype=np.dtype('float'),
comm=gc['default_comm']):
def __init__(self, num, dtype=np.dtype('float')):
"""
Sets the attributes for a point_space class instance.
......@@ -818,7 +817,6 @@ class point_space(space):
"WARNING: incompatible dtype: " + str(dtype)))
self.dtype = dtype
self.comm = self._parse_comm(comm)
self.discrete = True
# self.harmonic = False
self.distances = (np.float(1),)
......@@ -851,29 +849,9 @@ class point_space(space):
# Return the sorted identifiers as a tuple.
return tuple(sorted(temp))
def _parse_comm(self, comm):
# check if comm is a string -> the name of comm is given
# -> Extract it from the mpi_module
if isinstance(comm, str):
if gc.validQ('default_comm', comm):
result_comm = getattr(gdi[gc['mpi_module']], comm)
else:
raise ValueError(about._errors.cstring(
"ERROR: The given communicator-name is not supported."))
# check if the given comm object is an instance of default Intracomm
else:
if isinstance(comm, gdi[gc['mpi_module']].Intracomm):
result_comm = comm
else:
raise ValueError(about._errors.cstring(
"ERROR: The given comm object is not an instance of the " +
"default-MPI-module's Intracomm Class."))
return result_comm
def copy(self):
return point_space(num=self.paradict['num'],
dtype=self.dtype,
comm=self.comm)
dtype=self.dtype)
def getitem(self, data, key):
return data[key]
......@@ -1113,8 +1091,7 @@ class point_space(space):
# Case 2: x is something else
# Use general d2o casting
else:
x = distributed_data_object(x,
global_shape=self.get_shape(),
x = distributed_data_object(x, global_shape=self.get_shape(),
dtype=dtype)
# Cast the d2o
return self.cast(x, dtype=dtype)
......@@ -1617,7 +1594,6 @@ class point_space(space):
string += str(type(self)) + "\n"
string += "paradict: " + str(self.paradict) + "\n"
string += 'dtype: ' + str(self.dtype) + "\n"
string += 'comm: ' + self.comm.name + "\n"
string += 'discrete: ' + str(self.discrete) + "\n"
string += 'distances: ' + str(self.distances) + "\n"
return string
......@@ -1709,8 +1685,8 @@ class field(object):
"""
def __init__(self, domain=None, val=None, codomain=None,
copy=False, dtype=np.dtype('float64'), datamodel='not',
def __init__(self, domain=None, val=None, codomain=None, comm=gc[
'default_comm'], copy=False, dtype=np.dtype('float64'), datamodel='not',
**kwargs):
"""
Sets the attributes for a field class instance.
......@@ -1740,6 +1716,7 @@ class field(object):
self._init_from_field(f=val,
domain=domain,
codomain=codomain,
comm=comm,
copy=copy,
dtype=dtype,
datamodel=datamodel,
......@@ -1748,12 +1725,14 @@ class field(object):
self._init_from_array(val=val,
domain=domain,
codomain=codomain,
comm=comm,
copy=copy,
dtype=dtype,
datamodel=datamodel,
**kwargs)
def _init_from_field(self, f, domain, codomain, copy, dtype, datamodel,
def _init_from_field(self, f, domain, codomain, comm, copy, dtype,
datamodel,
**kwargs):
# check domain
if domain is None:
......@@ -1776,16 +1755,18 @@ class field(object):
self._init_from_array(domain=domain,
val=f.val,
codomain=codomain,
comm=comm,
copy=copy,
dtype=dtype,
datamodel=datamodel,
**kwargs)
def _init_from_array(self, val, domain, codomain, copy, dtype, datamodel,
**kwargs):
def _init_from_array(self, val, domain, codomain, comm, copy, dtype,
datamodel, **kwargs):
if dtype is None:
dtype = np.dtype('float64')
self.dtype = dtype
self.comm = self._parse_comm(comm)
if datamodel not in DISTRIBUTION_STRATEGIES['global']:
about.warnings.cprint("WARNING: datamodel set to default.")
......@@ -1813,6 +1794,25 @@ class field(object):
**kwargs))
self.set_val(new_val=val, copy=copy)
def _parse_comm(self, comm):
# check if comm is a string -> the name of comm is given
# -> Extract it from the mpi_module
if isinstance(comm, str):
if gc.validQ('default_comm', comm):
result_comm = getattr(gdi[gc['mpi_module']], comm)
else:
raise ValueError(about._errors.cstring(
"ERROR: The given communicator-name is not supported."))
# check if the given comm object is an instance of default Intracomm
else:
if isinstance(comm, gdi[gc['mpi_module']].Intracomm):
result_comm = comm
else:
raise ValueError(about._errors.cstring(
"ERROR: The given comm object is not an instance of the " +
"default-MPI-module's Intracomm Class."))
return result_comm
def check_valid_domain(self, domain):
if not isinstance(domain, np.ndarray):
raise TypeError(about._errors.cstring(
......@@ -1886,22 +1886,29 @@ class field(object):
self.domain.unary_operation(self.val, op='copy_empty')
return new_field
def copy_empty(self, domain=None, codomain=None, ishape=None, **kwargs):
def copy_empty(self, domain=None, codomain=None, dtype=None, comm=None,
datamodel=None, **kwargs):
if domain is None:
domain = self.domain
if codomain is None:
codomain = self.codomain
if ishape is None:
ishape = self.ishape
if dtype is None:
dtype = self.dtype
if comm is None:
comm = self.comm
if datamodel is None:
datamodel = self.datamodel
if (domain is self.domain and
codomain is self.codomain and
ishape == self.ishape and
dtype == self.dtype and
comm == self.comm and
datamodel == self.datamodel and
kwargs == {}):
new_field = self._fast_copy_empty()
else:
new_field = field(domain=domain, codomain=codomain, ishape=ishape,
**kwargs)
new_field = field(domain=domain, codomain=codomain, dtype=dtype,
comm=comm, datamodel=datamodel, **kwargs)
return new_field
def set_val(self, new_val=None, copy=False):
......
......@@ -122,8 +122,7 @@ class rg_space(point_space):
epsilon = 0.0001 # relative precision for comparisons
def __init__(self, shape, zerocenter=False, complexity=0, distances=None,
harmonic=False, fft_module=gc['fft_module'],
comm=gc['default_comm']):
harmonic=False, fft_module=gc['fft_module']):
"""
Sets the attributes for an rg_space class instance.
......@@ -185,8 +184,6 @@ class rg_space(point_space):
self.harmonic = bool(harmonic)
self.discrete = False
self.comm = self._parse_comm(comm)
# Initializes the fast-fourier-transform machine, which will be used
# to transform the space
if not gc.validQ('fft_module', fft_module):
......@@ -201,7 +198,6 @@ class rg_space(point_space):
shape=self.get_shape(),
dgrid=distances,
zerocentered=self.paradict['zerocenter'],
comm=self.comm,
allowed_distribution_strategies=RG_DISTRIBUTION_STRATEGIES)
@property
......@@ -237,8 +233,7 @@ class rg_space(point_space):
isinstance(x, np.ndarray) else x)(ii[1])))
for ii in vars(self).iteritems()
if ii[0] not in ['_cache_dict', 'fft_machine',
'power_indices', 'comm']]
temp.append(('comm', self.comm.__hash__()))
'power_indices']]
# Return the sorted identifiers as a tuple.
return tuple(sorted(temp))
......@@ -248,8 +243,7 @@ class rg_space(point_space):
zerocenter=self.paradict['zerocenter'],
distances=self.distances,
harmonic=self.harmonic,
fft_module=self.fft_machine.name,
comm=self.comm)
fft_module=self.fft_machine.name)
def get_shape(self):
return tuple(self.paradict['shape'])
......@@ -366,9 +360,6 @@ class rg_space(point_space):
raise TypeError(about._errors.cstring(
"ERROR: The given codomain must be a nifty rg_space."))
if self.comm is not codomain.comm:
return False
# check number of number and size of axes
if not np.all(np.array(self.paradict['shape']) ==
np.array(codomain.paradict['shape'])):
......@@ -470,7 +461,6 @@ class rg_space(point_space):
distances = 1 / (np.array(self.paradict['shape']) *
np.array(self.distances))
fft_module = self.fft_machine.name
comm = self.comm
complexity = {0: 1, 1: 0, 2: 2}[self.paradict['complexity']]
harmonic = bool(not self.harmonic)
......@@ -479,8 +469,7 @@ class rg_space(point_space):
complexity=complexity,
distances=distances,
harmonic=harmonic,
fft_module=fft_module,
comm=comm)
fft_module=fft_module)
return new_space
def get_random_values(self, **kwargs):
......
......@@ -488,7 +488,7 @@ class Test_Point_Space(unittest.TestCase):
###############################################################################
@parameterized.expand(
itertools.product(all_point_datatypes,
itertools.product(all_point_datatypes),
testcase_func_name=custom_name_func)
def test_cast_from_ndarray(self, dtype):
num = 10
......@@ -699,8 +699,7 @@ class Test_RG_Space(unittest.TestCase):
###############################################################################
@parameterized.expand(
testcase_func_name=custom_name_func)
@parameterized.expand([], testcase_func_name=custom_name_func)
def test_cast_to_hermitian(self):
shape = (10, 10)
x = rg_space(shape, complexity=1)
......@@ -711,8 +710,7 @@ class Test_RG_Space(unittest.TestCase):
###############################################################################
@parameterized.expand(
testcase_func_name=custom_name_func)
@parameterized.expand([], testcase_func_name=custom_name_func)
def test_enforce_power(self):
shape = (6, 6)
x = rg_space(shape)
......@@ -773,8 +771,7 @@ class Test_RG_Space(unittest.TestCase):
###############################################################################
@parameterized.expand(
testcase_func_name=custom_name_func)
@parameterized.expand([], testcase_func_name=custom_name_func)
def test_calc_dot(self):
shape = (8, 8)
a = np.arange(np.prod(shape)).reshape(shape)
......@@ -932,7 +929,7 @@ class Test_RG_Space(unittest.TestCase):
###############################################################################
@parameterized.expand(testcase_func_name=custom_name_func)
@parameterized.expand([],testcase_func_name=custom_name_func)
def test_calc_smooth(self):
sigma = 0.01
shape = (8, 8)
......@@ -959,7 +956,7 @@ class Test_RG_Space(unittest.TestCase):
###############################################################################
@parameterized.expand(testcase_func_name=custom_name_func)
@parameterized.expand([], testcase_func_name=custom_name_func)
def test_calc_power(self):
shape = (8, 8)
a = np.arange(np.prod(shape)).reshape(shape)
......@@ -1034,8 +1031,7 @@ class Test_Lm_Space(unittest.TestCase):
###############################################################################
@parameterized.expand(
testcase_func_name=custom_name_func)
@parameterized.expand([], testcase_func_name=custom_name_func)
def test_enforce_power(self):
lmax = 17
mmax = 12
......@@ -1058,7 +1054,7 @@ class Test_Lm_Space(unittest.TestCase):
##############################################################################
@parameterized.expand(testcase_func_name=custom_name_func)
@parameterized.expand([], testcase_func_name=custom_name_func)
def test_get_check_codomain(self):
lmax = 23
mmax = 23
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment