Commit 23e0643e authored by Ultima

Greatly improved the handling of multidimensional fields.

Added some ufuncs to point_space, field and distributed_data_object.
Updated the projection_operator.
parent 0f1e4c12
@@ -170,7 +170,14 @@ class lm_space(point_space):
 about.warnings.cprint("WARNING: data type set to default.")
 datatype = np.complex128
 self.datatype = datatype
+## set datamodel
+if datamodel not in ['np']:
+about.warnings.cprint("WARNING: datamodel set to default.")
+self.datamodel = 'np'
+else:
+self.datamodel = datamodel
 self.discrete = True
 self.vol = np.real(np.array([1],dtype=self.datatype))
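The same `datamodel` validation is added here and in the gl_space and hp_space constructors below. As a minimal standalone sketch of the fallback logic (a plain print stands in for nifty's about.warnings machinery, and _validated_datamodel is a hypothetical helper name, not part of the package):

# Sketch of the datamodel fallback used in the constructors above; 'np' is the
# only accepted backend at this point, anything else degrades to 'np' with a warning.
def _validated_datamodel(datamodel):
    if datamodel not in ['np']:
        print("WARNING: datamodel set to default.")
        return 'np'
    return datamodel

assert _validated_datamodel('np') == 'np'
assert _validated_datamodel('d2o') == 'np'   # falls back, with a warning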
@@ -1008,7 +1015,7 @@ class gl_space(point_space):
 vol : numpy.ndarray
 An array containing the pixel sizes.
 """
-def __init__(self,nlat,nlon=None,datatype=None):
+def __init__(self, nlat, nlon=None, datatype=None, datamodel='np'):
 """
 Sets the attributes for a gl_space class instance.
@@ -1047,6 +1054,13 @@ class gl_space(point_space):
 about.warnings.cprint("WARNING: data type set to default.")
 datatype = np.float64
 self.datatype = datatype
+## set datamodel
+if datamodel not in ['np']:
+about.warnings.cprint("WARNING: datamodel set to default.")
+self.datamodel = 'np'
+else:
+self.datamodel = datamodel
 self.discrete = False
 self.vol = gl.vol(self.paradict['nlat'],nlon=self.paradict['nlon']).astype(self.datatype)
@@ -1702,7 +1716,7 @@ class hp_space(point_space):
 """
 niter = 0 ## default number of iterations used for transformations
-def __init__(self, nside):
+def __init__(self, nside, datamodel = 'np'):
 """
 Sets the attributes for a hp_space class instance.
@@ -1731,6 +1745,14 @@ class hp_space(point_space):
 self.paradict = hp_space_paradict(nside=nside)
 self.datatype = np.float64
+## set datamodel
+if datamodel not in ['np']:
+about.warnings.cprint("WARNING: datamodel set to default.")
+self.datamodel = 'np'
+else:
+self.datamodel = datamodel
 self.discrete = False
 self.vol = np.array([4*pi/(12*self.paradict['nside']**2)],dtype=self.datatype)
......
This diff is collapsed.
@@ -132,7 +132,7 @@ class distributed_data_object(object):
 except(AttributeError):
 dtype = np.array(global_data).dtype.type
 else:
-dtype = dtype
+dtype = np.dtype(dtype).type
 ## an explicitly given global_shape argument is only used if
 ## 1. no global_data was supplied, or
@@ -171,7 +171,7 @@ class distributed_data_object(object):
 ## If the input data was a scalar, set the whole array to this value
 elif global_data != None and np.isscalar(global_data):
-temp = np.empty(self.distributor.local_shape)
+temp = np.empty(self.distributor.local_shape, dtype = self.dtype)
 temp.fill(global_data)
 self.set_local_data(temp)
 self.hermitian = True
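The dtype fix above matters because np.empty without an explicit dtype always allocates float64, so filling with a scalar would silently ignore the object's own dtype. A short numpy-only illustration:

# Without dtype, np.empty defaults to float64; with an explicit dtype the
# filled array keeps the requested type.
import numpy as np

a = np.empty((3,))
a.fill(2)
print(a.dtype)   # float64

b = np.empty((3,), dtype=np.int64)
b.fill(2)
print(b.dtype)   # int64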
@@ -243,12 +243,14 @@ class distributed_data_object(object):
 def __repr__(self):
 return '<distributed_data_object>\n'+self.data.__repr__()
-def __eq__(self, other):
+def _compare_helper(self, other, op):
 result = self.copy_empty(dtype = np.bool_)
 ## Case 1: 'other' is a scalar
 ## -> make point-wise comparison
 if np.isscalar(other):
-result.set_local_data(self.get_local_data(copy = False) == other)
+result.set_local_data(
+getattr(self.get_local_data(copy = False), op)(other))
 return result
 ## Case 2: 'other' is a numpy array or a distributed_data_object
@@ -256,7 +258,8 @@ class distributed_data_object(object):
 elif isinstance(other, np.ndarray) or\
 isinstance(other, distributed_data_object):
 temp_data = self.distributor.extract_local_data(other)
-result.set_local_data(self.get_local_data(copy=False) == temp_data)
+result.set_local_data(
+getattr(self.get_local_data(copy=False), op)(temp_data))
 return result
 ## Case 3: 'other' is None
@@ -267,11 +270,27 @@ class distributed_data_object(object):
 ## -> make a numpy casting and make a recursive call
 else:
 temp_other = np.array(other)
-return self.__eq__(temp_other)
+return getattr(self, op)(temp_other)
+def __ne__(self, other):
+return self._compare_helper(other, '__ne__')
+def __lt__(self, other):
+return self._compare_helper(other, '__lt__')
+def __le__(self, other):
+return self._compare_helper(other, '__le__')
+def __eq__(self, other):
+return self._compare_helper(other, '__eq__')
+def __ge__(self, other):
+return self._compare_helper(other, '__ge__')
+def __gt__(self, other):
+return self._compare_helper(other, '__gt__')
 def equal(self, other):
 if other is None:
 return False
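All six rich comparisons now funnel through the single _compare_helper, which looks up the matching element-wise numpy method by name. A self-contained sketch of that dispatch idea, using a plain wrapper class rather than the actual distributed_data_object:

# getattr(array, '__lt__')(other) and friends return element-wise comparison
# results, so one helper can serve __eq__, __ne__, __lt__, __le__, __ge__, __gt__.
import numpy as np

class wrapped_array(object):
    def __init__(self, data):
        self.data = np.asarray(data)
    def _compare_helper(self, other, op):
        return getattr(self.data, op)(other)
    def __eq__(self, other):
        return self._compare_helper(other, '__eq__')
    def __lt__(self, other):
        return self._compare_helper(other, '__lt__')

print(wrapped_array([1, 2, 3]) == 2)   # [False  True False]
print(wrapped_array([1, 2, 3]) < 3)    # [ True  True False]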
@@ -448,7 +467,7 @@ class distributed_data_object(object):
 def __len__(self):
 return self.shape[0]
-def dim(self):
+def get_dim(self):
 return np.prod(self.shape)
 def vdot(self, other):
@@ -620,7 +639,34 @@ class distributed_data_object(object):
 global_any = self.distributor._allgather(local_any)
 return np.any(global_any)
+def unique(self):
+local_unique = np.unique(self.get_local_data())
+global_unique = self.distributor._allgather(local_unique)
+global_unique = np.concatenate(global_unique)
+return np.unique(global_unique)
+def bincount(self, weights = None, minlength = None):
+if np.dtype(self.dtype).type not in [np.int8, np.int16, np.int32,
+np.int64, np.uint8, np.uint16, np.uint32, np.uint64]:
+raise TypeError(about._errors.cstring(
+"ERROR: Distributed-data-object must be of integer datatype!"))
+minlength = max(self.amax()+1, minlength)
+if weights is not None:
+local_weights = self.distributor.extract_local_data(weights).\
+flatten()
+else:
+local_weights = None
+local_counts = np.bincount(self.get_local_data().flatten(),
+weights = local_weights,
+minlength = minlength)
+list_of_counts = self.distributor._allgather(local_counts)
+print list_of_counts
+counts = np.sum(list_of_counts, axis = 0)
+return counts
 def set_local_data(self, data, hermitian=False, copy=True):
 """
......
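The new unique() and bincount() methods follow the same gather-and-reduce scheme: compute a local result per process, allgather the partial results, then reduce them with numpy. An MPI-free emulation of the bincount logic, where the chunks below stand in for get_local_data() on each rank:

# Each "process" computes a local bincount with a common minlength; the
# gathered count vectors then have equal length and can simply be summed.
import numpy as np

chunks = [np.array([0, 1, 1, 3]), np.array([3, 3, 2])]    # data of rank 0 and 1
minlength = int(max(c.max() for c in chunks)) + 1          # global amax() + 1
local_counts = [np.bincount(c, minlength=minlength) for c in chunks]
print(np.sum(local_counts, axis=0))   # [1 2 1 3] == np.bincount of all the data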
@@ -508,7 +508,46 @@ def direct_dot(x, y):
 return np.vdot(x, y)
+def convert_nested_list_to_object_array(x):
+## if x is a nested_list full of ndarrays all having the same size,
+## np.shape returns the shape of the ndarrays, too, i.e. too many
+## dimensions
+possible_shape = np.shape(x)
+## Check if possible_shape goes too deep.
+dimension_counter = 0
+current_extract = x
+for i in xrange(len(possible_shape)):
+if isinstance(current_extract, list) == False and\
+isinstance(current_extract, tuple) == False:
+break
+current_extract = current_extract[0]
+dimension_counter += 1
+real_shape = possible_shape[:dimension_counter]
+## if the numpy array was not encapsulated at all, return x directly
+if real_shape == ():
+return x
+## Prepare the carrier-object
+carrier = np.empty(real_shape, dtype = np.object)
+for i in xrange(np.prod(real_shape)):
+ii = np.unravel_index(i, real_shape)
+try:
+carrier[ii] = x[ii]
+except(TypeError):
+extracted = x
+for j in xrange(len(ii)):
+extracted = extracted[ii[j]]
+carrier[ii] = extracted
+return carrier
......
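convert_nested_list_to_object_array exists because np.shape on a nested list of equally-sized arrays also counts the arrays' own axes. A small illustration of the problem and of the intended result, built directly with numpy:

# np.shape reports (1, 2, 3) for a depth-2 nesting of length-3 arrays; the
# helper trims this back to the nesting depth (1, 2) and stores each array as
# one element of an object array.
import numpy as np

x = [[np.arange(3), np.arange(3)]]
print(np.shape(x))            # (1, 2, 3) -- one dimension too many

carrier = np.empty((1, 2), dtype=object)
for j in range(2):
    carrier[0, j] = x[0][j]
print(carrier.shape)          # (1, 2)
print(carrier[0, 1])          # [0 1 2]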
This diff is collapsed.
@@ -349,7 +349,7 @@ class rg_space(point_space):
 ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-def dof(self):
+def get_dof(self):
 """
 Computes the number of degrees of freedom of the space, i.e.\ the
 number of grid points multiplied with one or two, depending on
@@ -556,16 +556,8 @@ class rg_space(point_space):
 ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-def cast(self, x, verbose = False):
-if self.datamodel == 'd2o':
-return self._cast_to_d2o(x = x, verbose = False)
-elif self.datamodel == 'np':
-return self._cast_to_np(x = x, verbose = False)
-else:
-raise NotImplementedError(about._errors.cstring(
-"ERROR: function is not implemented for given datamodel."))
-def _cast_to_d2o(self, x, verbose=False):
+def _cast_to_d2o(self, x, dtype = None, ignore_complexity = False,
+verbose=False, **kwargs):
 """
 Computes valid field values from a given object, trying
 to translate the given data into a valid form. Thereby it is as
@@ -588,6 +580,8 @@ class rg_space(point_space):
 Whether the method should raise a warning if information is
 lost during casting (default: False).
 """
+if dtype is None:
+dtype = self.datatype
 ## Case 1: x is a field
 if isinstance(x, field):
 if verbose:
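With the new signature, the casting helpers take an optional dtype and fall back to the space's own datatype when none is given. A reduced sketch of that convention (cast_array and default_dtype are illustrative names, not nifty API):

# Optional-dtype convention: None means "use the space's default datatype".
import numpy as np

def cast_array(x, default_dtype, dtype=None):
    if dtype is None:
        dtype = default_dtype
    return np.asarray(x, dtype=dtype)

print(cast_array([1, 2], np.float64).dtype)                        # float64
print(cast_array([1, 2], np.float64, dtype=np.complex128).dtype)   # complex128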
@@ -596,14 +590,14 @@ class rg_space(point_space):
 about.warnings.cflush(\
 "WARNING: Getting data from foreign domain!")
 ## Extract the data, whatever it is, and cast it again
-return self.cast(x.val)
+return self.cast(x.val, dtype=dtype)
 ## Case 2: x is a distributed_data_object
 if isinstance(x, distributed_data_object):
 ## Check the shape
 if np.any(x.shape != self.get_shape()):
 ## Check if at least the number of degrees of freedom is equal
-if x.dim() == self.get_dim():
+if x.get_dim() == self.get_dim():
 ## If the number of dof is equal or 1, use np.reshape...
 about.warnings.cflush(\
 "WARNING: Trying to reshape the data. This operation is "+\
@@ -611,54 +605,57 @@ class rg_space(point_space):
 temp = x.get_full_data()
 temp = np.reshape(temp, self.get_shape())
 ## ... and cast again
-return self.cast(temp)
+return self.cast(temp, dtype=dtype)
 else:
 raise ValueError(about._errors.cstring(\
 "ERROR: Data has incompatible shape!"))
 ## Check the datatype
-if x.dtype < self.datatype:
+if np.dtype(x.dtype) != np.dtype(dtype):
+if np.dtype(x.dtype) > np.dtype(dtype):
 about.warnings.cflush(\
 "WARNING: Datatypes are uneqal/of conflicting precision (own: "\
-+ str(self.datatype) + " <> foreign: " + str(x.dtype) \
++ str(dtype) + " <> foreign: " + str(x.dtype) \
 + ") and will be casted! "\
 + "Potential loss of precision!\n")
-temp = x.copy_empty(dtype=self.datatype)
+temp = x.copy_empty(dtype=dtype)
 temp.set_local_data(x.get_local_data())
 temp.hermitian = x.hermitian
 x = temp
+if ignore_complexity == False:
 ## Check hermitianity/reality
 if self.paradict['complexity'] == 0:
 if x.iscomplex().any() == True:
 about.warnings.cflush(\
 "WARNING: Data is not completely real. Imaginary part "+\
 "will be discarded!\n")
 temp = x.copy_empty()
 temp.set_local_data(np.real(x.get_local_data()))
 x = temp
 elif self.paradict['complexity'] == 1:
 if x.hermitian == False and about.hermitianize.status == True:
 about.warnings.cflush(\
 "WARNING: Data gets hermitianized. This operation is "+\
 "extremely expensive\n")
 #temp = x.copy_empty()
 #temp.set_full_data(gp.nhermitianize_fast(x.get_full_data(),
 # (False, )*len(x.shape)))
 x = utilities.hermitianize(x)
 return x
 ## Case 3: x is something else
 ## Use general d2o casting
 x = distributed_data_object(x, global_shape=self.get_shape(),\
-dtype=self.datatype)
+dtype=dtype)
 ## Cast the d2o
-return self.cast(x)
+return self.cast(x, dtype=dtype)
-def _cast_to_np(self, x, verbose = False):
+def _cast_to_np(self, x, dtype = None, ignore_complexity = False,
+verbose = False, **kwargs):
 """
 Computes valid field values from a given object, trying
 to translate the given data into a valid form. Thereby it is as
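The reworked precision check in _cast_to_d2o above relies on numpy's partial ordering of dtypes: roughly, np.dtype(a) > np.dtype(b) holds when b can be safely cast to a but a cannot be cast back to b, so the warning only fires when casting x down to the target dtype may lose information. A quick illustration:

# dtype ordering: "greater" dtypes cannot be cast to the target without
# potential loss, which is exactly when the warning above is emitted.
import numpy as np

print(np.dtype(np.complex128) > np.dtype(np.float64))   # True  -> warn before casting
print(np.dtype(np.float32) > np.dtype(np.float64))      # False -> silent, safe upcast
print(np.dtype(np.float64) == np.dtype('float64'))      # True  -> no cast at all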
@@ -681,6 +678,8 @@ class rg_space(point_space):
 Whether the method should raise a warning if information is
 lost during casting (default: False).
 """
+if dtype is None:
+dtype = self.datatype
 ## Case 1: x is a field
 if isinstance(x, field):
 if verbose:
@@ -689,14 +688,14 @@ class rg_space(point_space):
 about.warnings.cflush(\
 "WARNING: Getting data from foreign domain!")
 ## Extract the data, whatever it is, and cast it again
-return self.cast(x.val)
+return self.cast(x.val, dtype=dtype)
 ## Case 2: x is a distributed_data_object
 if isinstance(x, distributed_data_object):
 ## Extract the data
 temp = x.get_full_data()
 ## Cast the resulting numpy array again
-return self.cast(temp)
+return self.cast(temp, dtype=dtype)
 elif isinstance(x, np.ndarray):
 ## Check the shape
@@ -706,52 +705,53 @@ class rg_space(point_space):
 ## If the number of dof is equal or 1, use np.reshape...
 temp = x.reshape(self.get_shape())
 ## ... and cast again
-return self.cast(temp)
+return self.cast(temp, dtype=dtype)
 elif x.size == 1:
 temp = np.empty(shape = self.get_shape(),
-dtype = self.datatype)
+dtype = dtype)
 temp[:] = x
-return self.cast(temp)
+return self.cast(temp, dtype=dtype)
 else:
 raise ValueError(about._errors.cstring(\
 "ERROR: Data has incompatible shape!"))
 ## Check the datatype
-if x.dtype < self.datatype:
+if x.dtype != dtype:
 about.warnings.cflush(\
 "WARNING: Datatypes are uneqal/of conflicting precision (own: "\
-+ str(self.datatype) + " <> foreign: " + str(x.dtype) \
++ str(dtype) + " <> foreign: " + str(x.dtype) \
 + ") and will be casted! "\
 + "Potential loss of precision!\n")
 ## Fix the datatype...
-temp = x.astype(self.datatype)
+temp = x.astype(dtype)
 ##... and cast again
-return self.cast(temp)
+return self.cast(temp, dtype=dtype)
+if ignore_complexity == False:
 ## Check hermitianity/reality
 if self.paradict['complexity'] == 0:
 if not np.all(np.isreal(x)) == True:
 about.warnings.cflush(\
 "WARNING: Data is not completely real. Imaginary part "+\
 "will be discarded!\n")
 x = np.real(x)
 elif self.paradict['complexity'] == 1:
 if about.hermitianize.status == True:
 about.warnings.cflush(\
 "WARNING: Data gets hermitianized. This operation is "+\
 "rather expensive.\n")
 #temp = x.copy_empty()
 #temp.set_full_data(gp.nhermitianize_fast(x.get_full_data(),
 # (False, )*len(x.shape)))
 x = utilities.hermitianize(x)
 return x
 ## Case 3: x is something else
 ## Use general numpy casting
 else:
-temp = np.empty(self.get_shape(), dtype = self.datatype)
+temp = np.empty(self.get_shape(), dtype = dtype)
 temp[:] = x
 return temp
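The fallback branch at the end of _cast_to_np builds the result by broadcasting into a freshly allocated array; np.empty followed by slice assignment accepts scalars as well as anything broadcast-compatible with the space's shape. A plain numpy illustration of that pattern:

# temp[:] = x broadcasts x over the full array, which is why the branch above
# works for scalars and lower-dimensional inputs alike.
import numpy as np

temp = np.empty((2, 3), dtype=np.float64)
temp[:] = 5            # scalar fill
print(temp)
temp[:] = [1, 2, 3]    # broadcast a row across axis 0
print(temp)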
@@ -1513,7 +1513,7 @@ class rg_space(point_space):
 ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++