Commit 23e0643e authored by Ultima's avatar Ultima
Browse files

Improved the multidimensionality of fields; a lot.

Added some ufuncs to point_space, field and distributed_data_object.
Updated the projection_operator.
parent 0f1e4c12
......@@ -171,6 +171,13 @@ class lm_space(point_space):
datatype = np.complex128
self.datatype = datatype
## set datamodel
if datamodel not in ['np']:
about.warnings.cprint("WARNING: datamodel set to default.")
self.datamodel = 'np'
else:
self.datamodel = datamodel
self.discrete = True
self.vol = np.real(np.array([1],dtype=self.datatype))
......@@ -1008,7 +1015,7 @@ class gl_space(point_space):
vol : numpy.ndarray
An array containing the pixel sizes.
"""
def __init__(self,nlat,nlon=None,datatype=None):
def __init__(self, nlat, nlon=None, datatype=None, datamodel='np'):
"""
Sets the attributes for a gl_space class instance.
......@@ -1048,6 +1055,13 @@ class gl_space(point_space):
datatype = np.float64
self.datatype = datatype
## set datamodel
if datamodel not in ['np']:
about.warnings.cprint("WARNING: datamodel set to default.")
self.datamodel = 'np'
else:
self.datamodel = datamodel
self.discrete = False
self.vol = gl.vol(self.paradict['nlat'],nlon=self.paradict['nlon']).astype(self.datatype)
......@@ -1702,7 +1716,7 @@ class hp_space(point_space):
"""
niter = 0 ## default number of iterations used for transformations
def __init__(self, nside):
def __init__(self, nside, datamodel = 'np'):
"""
Sets the attributes for a hp_space class instance.
......@@ -1731,6 +1745,14 @@ class hp_space(point_space):
self.paradict = hp_space_paradict(nside=nside)
self.datatype = np.float64
## set datamodel
if datamodel not in ['np']:
about.warnings.cprint("WARNING: datamodel set to default.")
self.datamodel = 'np'
else:
self.datamodel = datamodel
self.discrete = False
self.vol = np.array([4*pi/(12*self.paradict['nside']**2)],dtype=self.datatype)
......
This diff is collapsed.
......@@ -132,7 +132,7 @@ class distributed_data_object(object):
except(AttributeError):
dtype = np.array(global_data).dtype.type
else:
dtype = dtype
dtype = np.dtype(dtype).type
## an explicitly given global_shape argument is only used if
## 1. no global_data was supplied, or
......@@ -171,7 +171,7 @@ class distributed_data_object(object):
## If the input data was a scalar, set the whole array to this value
elif global_data != None and np.isscalar(global_data):
temp = np.empty(self.distributor.local_shape)
temp = np.empty(self.distributor.local_shape, dtype = self.dtype)
temp.fill(global_data)
self.set_local_data(temp)
self.hermitian = True
......@@ -243,12 +243,14 @@ class distributed_data_object(object):
def __repr__(self):
return '<distributed_data_object>\n'+self.data.__repr__()
def __eq__(self, other):
def _compare_helper(self, other, op):
result = self.copy_empty(dtype = np.bool_)
## Case 1: 'other' is a scalar
## -> make point-wise comparison
if np.isscalar(other):
result.set_local_data(self.get_local_data(copy = False) == other)
result.set_local_data(
getattr(self.get_local_data(copy = False), op)(other))
return result
## Case 2: 'other' is a numpy array or a distributed_data_object
......@@ -256,7 +258,8 @@ class distributed_data_object(object):
elif isinstance(other, np.ndarray) or\
isinstance(other, distributed_data_object):
temp_data = self.distributor.extract_local_data(other)
result.set_local_data(self.get_local_data(copy=False) == temp_data)
result.set_local_data(
getattr(self.get_local_data(copy=False), op)(temp_data))
return result
## Case 3: 'other' is None
......@@ -267,10 +270,26 @@ class distributed_data_object(object):
## -> make a numpy casting and make a recursive call
else:
temp_other = np.array(other)
return self.__eq__(temp_other)
return getattr(self, op)(temp_other)
def __ne__(self, other):
return self._compare_helper(other, '__ne__')
def __lt__(self, other):
return self._compare_helper(other, '__lt__')
def __le__(self, other):
return self._compare_helper(other, '__le__')
def __eq__(self, other):
return self._compare_helper(other, '__eq__')
def __ge__(self, other):
return self._compare_helper(other, '__ge__')
def __gt__(self, other):
return self._compare_helper(other, '__gt__')
def equal(self, other):
if other is None:
......@@ -448,7 +467,7 @@ class distributed_data_object(object):
def __len__(self):
return self.shape[0]
def dim(self):
def get_dim(self):
return np.prod(self.shape)
def vdot(self, other):
......@@ -620,6 +639,33 @@ class distributed_data_object(object):
global_any = self.distributor._allgather(local_any)
return np.any(global_any)
def unique(self):
local_unique = np.unique(self.get_local_data())
global_unique = self.distributor._allgather(local_unique)
global_unique = np.concatenate(global_unique)
return np.unique(global_unique)
def bincount(self, weights = None, minlength = None):
if np.dtype(self.dtype).type not in [np.int8, np.int16, np.int32,
np.int64, np.uint8, np.uint16, np.uint32, np.uint64]:
raise TypeError(about._errors.cstring(
"ERROR: Distributed-data-object must be of integer datatype!"))
minlength = max(self.amax()+1, minlength)
if weights is not None:
local_weights = self.distributor.extract_local_data(weights).\
flatten()
else:
local_weights = None
local_counts = np.bincount(self.get_local_data().flatten(),
weights = local_weights,
minlength = minlength)
list_of_counts = self.distributor._allgather(local_counts)
print list_of_counts
counts = np.sum(list_of_counts, axis = 0)
return counts
def set_local_data(self, data, hermitian=False, copy=True):
......
......@@ -508,6 +508,45 @@ def direct_dot(x, y):
return np.vdot(x, y)
def convert_nested_list_to_object_array(x):
## if x is a nested_list full of ndarrays all having the same size,
## np.shape returns the shape of the ndarrays, too, i.e. too many
## dimensions
possible_shape = np.shape(x)
## Check if possible_shape goes too deep.
dimension_counter = 0
current_extract = x
for i in xrange(len(possible_shape)):
if isinstance(current_extract, list) == False and\
isinstance(current_extract, tuple) == False:
break
current_extract = current_extract[0]
dimension_counter += 1
real_shape = possible_shape[:dimension_counter]
## if the numpy array was not encapsulated at all, return x directly
if real_shape == ():
return x
## Prepare the carrier-object
carrier = np.empty(real_shape, dtype = np.object)
for i in xrange(np.prod(real_shape)):
ii = np.unravel_index(i, real_shape)
try:
carrier[ii] = x[ii]
except(TypeError):
extracted = x
for j in xrange(len(ii)):
extracted = extracted[ii[j]]
carrier[ii] = extracted
return carrier
......
This diff is collapsed.
......@@ -349,7 +349,7 @@ class rg_space(point_space):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def dof(self):
def get_dof(self):
"""
Computes the number of degrees of freedom of the space, i.e.\ the
number of grid points multiplied with one or two, depending on
......@@ -556,16 +556,8 @@ class rg_space(point_space):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def cast(self, x, verbose = False):
if self.datamodel == 'd2o':
return self._cast_to_d2o(x = x, verbose = False)
elif self.datamodel == 'np':
return self._cast_to_np(x = x, verbose = False)
else:
raise NotImplementedError(about._errors.cstring(
"ERROR: function is not implemented for given datamodel."))
def _cast_to_d2o(self, x, verbose=False):
def _cast_to_d2o(self, x, dtype = None, ignore_complexity = False,
verbose=False, **kwargs):
"""
Computes valid field values from a given object, trying
to translate the given data into a valid form. Thereby it is as
......@@ -588,6 +580,8 @@ class rg_space(point_space):
Whether the method should raise a warning if information is
lost during casting (default: False).
"""
if dtype is None:
dtype = self.datatype
## Case 1: x is a field
if isinstance(x, field):
if verbose:
......@@ -596,14 +590,14 @@ class rg_space(point_space):
about.warnings.cflush(\
"WARNING: Getting data from foreign domain!")
## Extract the data, whatever it is, and cast it again
return self.cast(x.val)
return self.cast(x.val, dtype=dtype)
## Case 2: x is a distributed_data_object
if isinstance(x, distributed_data_object):
## Check the shape
if np.any(x.shape != self.get_shape()):
## Check if at least the number of degrees of freedom is equal
if x.dim() == self.get_dim():
if x.get_dim() == self.get_dim():
## If the number of dof is equal or 1, use np.reshape...
about.warnings.cflush(\
"WARNING: Trying to reshape the data. This operation is "+\
......@@ -611,24 +605,26 @@ class rg_space(point_space):
temp = x.get_full_data()
temp = np.reshape(temp, self.get_shape())
## ... and cast again
return self.cast(temp)
return self.cast(temp, dtype=dtype)
else:
raise ValueError(about._errors.cstring(\
"ERROR: Data has incompatible shape!"))
## Check the datatype
if x.dtype < self.datatype:
if np.dtype(x.dtype) != np.dtype(dtype):
if np.dtype(x.dtype) > np.dtype(dtype):
about.warnings.cflush(\
"WARNING: Datatypes are uneqal/of conflicting precision (own: "\
+ str(self.datatype) + " <> foreign: " + str(x.dtype) \
+ str(dtype) + " <> foreign: " + str(x.dtype) \
+ ") and will be casted! "\
+ "Potential loss of precision!\n")
temp = x.copy_empty(dtype=self.datatype)
temp = x.copy_empty(dtype=dtype)
temp.set_local_data(x.get_local_data())
temp.hermitian = x.hermitian
x = temp
if ignore_complexity == False:
## Check hermitianity/reality
if self.paradict['complexity'] == 0:
if x.iscomplex().any() == True:
......@@ -654,11 +650,12 @@ class rg_space(point_space):
## Case 3: x is something else
## Use general d2o casting
x = distributed_data_object(x, global_shape=self.get_shape(),\
dtype=self.datatype)
dtype=dtype)
## Cast the d2o
return self.cast(x)
return self.cast(x, dtype=dtype)
def _cast_to_np(self, x, verbose = False):
def _cast_to_np(self, x, dtype = None, ignore_complexity = False,
verbose = False, **kwargs):
"""
Computes valid field values from a given object, trying
to translate the given data into a valid form. Thereby it is as
......@@ -681,6 +678,8 @@ class rg_space(point_space):
Whether the method should raise a warning if information is
lost during casting (default: False).
"""
if dtype is None:
dtype = self.datatype
## Case 1: x is a field
if isinstance(x, field):
if verbose:
......@@ -689,14 +688,14 @@ class rg_space(point_space):
about.warnings.cflush(\
"WARNING: Getting data from foreign domain!")
## Extract the data, whatever it is, and cast it again
return self.cast(x.val)
return self.cast(x.val, dtype=dtype)
## Case 2: x is a distributed_data_object
if isinstance(x, distributed_data_object):
## Extract the data
temp = x.get_full_data()
## Cast the resulting numpy array again
return self.cast(temp)
return self.cast(temp, dtype=dtype)
elif isinstance(x, np.ndarray):
## Check the shape
......@@ -706,28 +705,29 @@ class rg_space(point_space):
## If the number of dof is equal or 1, use np.reshape...
temp = x.reshape(self.get_shape())
## ... and cast again
return self.cast(temp)
return self.cast(temp, dtype=dtype)
elif x.size == 1:
temp = np.empty(shape = self.get_shape(),
dtype = self.datatype)
dtype = dtype)
temp[:] = x
return self.cast(temp)
return self.cast(temp, dtype=dtype)
else:
raise ValueError(about._errors.cstring(\
"ERROR: Data has incompatible shape!"))
## Check the datatype
if x.dtype < self.datatype:
if x.dtype != dtype:
about.warnings.cflush(\
"WARNING: Datatypes are uneqal/of conflicting precision (own: "\
+ str(self.datatype) + " <> foreign: " + str(x.dtype) \
+ str(dtype) + " <> foreign: " + str(x.dtype) \
+ ") and will be casted! "\
+ "Potential loss of precision!\n")
## Fix the datatype...
temp = x.astype(self.datatype)
temp = x.astype(dtype)
##... and cast again
return self.cast(temp)
return self.cast(temp, dtype=dtype)
if ignore_complexity == False:
## Check hermitianity/reality
if self.paradict['complexity'] == 0:
if not np.all(np.isreal(x)) == True:
......@@ -751,7 +751,7 @@ class rg_space(point_space):
## Case 3: x is something else
## Use general numpy casting
else:
temp = np.empty(self.get_shape(), dtype = self.datatype)
temp = np.empty(self.get_shape(), dtype = dtype)
temp[:] = x
return temp
......@@ -1513,7 +1513,7 @@ class rg_space(point_space):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def calc_power(self,x,**kwargs):
def calc_power(self, x, **kwargs):
"""
Computes the power of an array of field values.
......@@ -1559,7 +1559,7 @@ class rg_space(point_space):
## If self is a position space, delegate calc_power to its codomain.
if self.fourier == False:
try:
codomain = kwargs.get('codomain')
codomain = kwargs['codomain']
except(KeyError):
codomain = self.get_codomain()
......@@ -1589,6 +1589,7 @@ class rg_space(point_space):
fieldabs = abs(x)**2
power_spectrum = np.zeros(rho.shape)
## TODO: Rework to real numpy-ness
if self.datamodel == 'np':
working_field = pindex.copy_empty(dtype = fieldabs.dtype)
working_field.set_full_data(data = fieldabs)
......@@ -1599,7 +1600,8 @@ class rg_space(point_space):
pindex.distributor._allgather(local_power_spectrum)
power_spectrum = np.sum(power_spectrum, axis = 0)
if self.datamodel == 'd2o':
## TODO: Use d2o.bincount in order to simplify this
elif self.datamodel == 'd2o':
## In order to make the summation over identical pindices fast,
## the pindex and the kindex must have the same distribution strategy
if pindex.distribution_strategy == fieldabs.distribution_strategy and\
......@@ -1610,10 +1612,12 @@ class rg_space(point_space):
working_field.inject((slice(None),), fieldabs, (slice(None,)))
local_power_spectrum = np.bincount(pindex.get_local_data().flatten(),
weights = working_field.get_local_data().flatten())
weights = working_field.get_local_data().flatten(),
minlength = len(rho))
power_spectrum =\
pindex.distributor._allgather(local_power_spectrum)
power_spectrum = np.sum(power_spectrum, axis = 0)
else:
raise NotImplementedError(about._errors.cstring(
"ERROR: function is not implemented for given datamodel."))
......@@ -2201,6 +2205,7 @@ class power_indices(object):
## store the local pindex data in the global_pindex d2o
global_pindex.set_local_data(local_pindex)
## TODO: Use the universal capabilities of bincount in oder to speed this up!
#######
# rho #
#######
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment