Commit 6c50de1e authored by Ultima's avatar Ultima
Browse files

Added first tests for point_space.

parent a041bd3f
......@@ -2,15 +2,18 @@
import numpy as np
def MIN():
return np.min
def MAX():
return np.max
def SUM():
return np.sum
#def MIN():
# return np.min
#
#def MAX():
# return np.max
#
#def SUM():
# return np.sum
MIN = np.min
MAX = np.max
SUM = np.sum
class _COMM_WORLD():
......@@ -27,7 +30,7 @@ class _COMM_WORLD():
def _scattergather_helper(self, sendbuf, recvbuf=None, **kwargs):
sendbuf = self._unwrapper(sendbuf)
recvbuf = self._unwrapper(recvbuf)
if recvbuf != None:
if recvbuf is not None:
recvbuf[:] = sendbuf
return recvbuf
else:
......@@ -50,7 +53,7 @@ class _COMM_WORLD():
return self._scattergather_helper(*args, **kwargs)
def gather(self, sendbuf, *args, **kwargs):
return [sendbuf,]
return [sendbuf]
def Gather(self, *args, **kwargs):
return self._scattergather_helper(*args, **kwargs)
......@@ -59,7 +62,7 @@ class _COMM_WORLD():
return self._scattergather_helper(*args, **kwargs)
def allgather(self, sendbuf, *args, **kwargs):
return [sendbuf,]
return [sendbuf]
def Allgather(self, *args, **kwargs):
return self._scattergather_helper(*args, **kwargs)
......@@ -87,6 +90,7 @@ class _COMM_WORLD():
def Barrier(self):
pass
class _datatype():
def __init__(self, name):
self.name = str(name)
......@@ -108,14 +112,3 @@ DOUBLE_COMPLEX = _datatype("MPI_DOUBLE_COMPLEX")
COMM_WORLD = _COMM_WORLD()
......@@ -103,6 +103,9 @@ class configuration(object):
for key, item in self.variable_dict.items():
item.set_value(None)
def validQ(self, name, value):
return self.variable_dict[name].checker(value)
def save(self, path=None, path_section=None):
if path is None:
if self.path is None:
......
......@@ -20,11 +20,11 @@
## along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import division
from nifty.keepers import about
from nifty.keepers import about,\
global_dependency_injector as gdi
from distutils.version import LooseVersion as lv
try:
import libsharp_wrapper_gl as gl
except(ImportError):
......
......@@ -1087,7 +1087,7 @@ class point_space(space):
Pixel volume of the :py:class:`point_space`, which is always 1.
"""
def __init__(self, num, datatype=None, datamodel='fftw'):
def __init__(self, num, datatype=np.dtype('float'), datamodel='fftw'):
"""
Sets the attributes for a point_space class instance.
......@@ -1104,25 +1104,25 @@ class point_space(space):
"""
self.paradict = point_space_paradict(num=num)
# check datatype
if (datatype is None):
datatype = np.float64
elif (datatype not in [np.bool_,
np.int8,
np.int16,
np.int32,
np.int64,
np.float16,
np.float32,
np.float64,
np.complex64,
np.complex128]):
about.warnings.cprint("WARNING: data type set to default.")
datatype = np.float64
self.datatype = datatype
# parse datatype
dtype = np.dtype(datatype)
if dtype not in [np.dtype('bool'),
np.dtype('int8'),
np.dtype('int16'),
np.dtype('int32'),
np.dtype('int64'),
np.dtype('float16'),
np.dtype('float32'),
np.dtype('float64'),
np.dtype('complex64'),
np.dtype('complex128')]:
raise ValueError(about._errors.cstring(
"WARNING: incompatible datatype: " + str(dtype)))
self.dtype = dtype
self.datatype = dtype.type
if datamodel not in ['np'] + POINT_DISTRIBUTION_STRATEGIES:
about.warnings.cprint("WARNING: datamodel set to default.")
about._errors.cstring("WARNING: datamodel set to default.")
self.datamodel = \
global_configuration['default_distribution_strategy']
else:
......@@ -1139,7 +1139,7 @@ class point_space(space):
@para.setter
def para(self, x):
self.paradict['num'] = x
self.paradict['num'] = x[0]
def copy(self):
return point_space(num=self.paradict['num'],
......@@ -1442,21 +1442,34 @@ class point_space(space):
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def cast(self, x, dtype=None, verbose=False, **kwargs):
def cast(self, x, dtype=None, **kwargs):
if dtype is not None:
dtype = np.dtype(dtype).type
# If x is a field, extract the data and do a recursive call
if isinstance(x, field):
# Check if the domain matches
if self != x.domain:
about.warnings.cflush(
"WARNING: Getting data from foreign domain!")
# Extract the data, whatever it is, and cast it again
return self.cast(x.val,
dtype=dtype,
**kwargs)
if self.datamodel in POINT_DISTRIBUTION_STRATEGIES:
return self._cast_to_d2o(x=x, dtype=dtype, verbose=verbose,
return self._cast_to_d2o(x=x,
dtype=dtype,
**kwargs)
elif self.datamodel == 'np':
return self._cast_to_np(x=x, dtype=dtype, verbose=verbose,
return self._cast_to_np(x=x,
dtype=dtype,
**kwargs)
else:
raise NotImplementedError(about._errors.cstring(
"ERROR: function is not implemented for given datamodel."))
def _cast_to_d2o(self, x, dtype=None, verbose=False, **kwargs):
def _cast_to_d2o(self, x, dtype=None, **kwargs):
"""
Computes valid field values from a given object, trying
to translate the given data into a valid form. Thereby it is as
......@@ -1482,30 +1495,25 @@ class point_space(space):
if dtype is None:
dtype = self.datatype
# Case 1: x is a field
if isinstance(x, field):
if verbose:
# Check if the domain matches
if(self != x.domain):
about.warnings.cflush(
"WARNING: Getting data from foreign domain!")
# Extract the data, whatever it is, and cast it again
return self.cast(x.val, dtype=dtype)
# Case 2: x is a distributed_data_object
# Case 1: x is a distributed_data_object
if isinstance(x, distributed_data_object):
to_copy = False
# Check the shape
if np.any(x.shape != self.get_shape()):
# Check if at least the number of degrees of freedom is equal
if x.get_dim() == self.get_dim():
# If the number of dof is equal or 1, use np.reshape...
about.warnings.cflush(
"WARNING: Trying to reshape the data. This operation is " +
"expensive as it consolidates the full data!\n")
"WARNING: Trying to reshape the data. This " +
"operation is expensive as it consolidates the " +
"full data!\n")
temp = x.get_full_data()
temp = np.reshape(temp, self.get_shape())
# ... and cast again
return self.cast(temp, dtype=dtype)
return self._cast_to_d2o(temp,
dtype=dtype,
**kwargs)
else:
raise ValueError(about._errors.cstring(
......@@ -1514,25 +1522,35 @@ class point_space(space):
# Check the datatype
if x.dtype != dtype:
about.warnings.cflush(
"WARNING: Datatypes are uneqal/of conflicting precision (own: "
+ str(dtype) + " <> foreign: " + str(x.dtype)
+ ") and will be casted! "
+ "Potential loss of precision!\n")
temp = x.copy_empty(dtype=dtype)
"WARNING: Datatypes are uneqal/of conflicting precision " +
"(own: " + str(dtype) + " <> foreign: " + str(x.dtype) +
") and will be casted! Potential loss of precision!\n")
to_copy = True
# Check the distribution_strategy
if x.distribution_strategy != self.datamodel:
to_copy = True
if to_copy:
temp = x.copy_empty(dtype=dtype,
distribution_strategy=self.datamodel)
temp.set_local_data(x.get_local_data())
temp.hermitian = x.hermitian
x = temp
return x
# Case 3: x is something else
# Case 2: x is something else
# Use general d2o casting
x = distributed_data_object(x, global_shape=self.get_shape(),
dtype=dtype)
# Cast the d2o
return self.cast(x, dtype=dtype)
else:
x = distributed_data_object(x,
global_shape=self.get_shape(),
dtype=dtype,
distribution_strategy=self.datamodel)
# Cast the d2o
return self.cast(x, dtype=dtype)
def _cast_to_np(self, x, dtype=None, verbose=False, **kwargs):
def _cast_to_np(self, x, dtype=None, **kwargs):
"""
Computes valid field values from a given object, trying
to translate the given data into a valid form. Thereby it is as
......@@ -1558,23 +1576,16 @@ class point_space(space):
if dtype is None:
dtype = self.datatype
# Case 1: x is a field
if isinstance(x, field):
if verbose:
# Check if the domain matches
if(self != x.domain):
about.warnings.cflush(
"WARNING: Getting data from foreign domain!")
# Extract the data, whatever it is, and cast it again
return self.cast(x.val, dtype=dtype)
# Case 2: x is a distributed_data_object
# Case 1: x is a distributed_data_object
if isinstance(x, distributed_data_object):
# Extract the data
temp = x.get_full_data()
# Cast the resulting numpy array again
return self.cast(temp, dtype=dtype)
return self._cast_to_np(temp,
dtype=dtype,
**kwargs)
# Case 2: x is a distributed_data_object
elif isinstance(x, np.ndarray):
# Check the shape
if np.any(x.shape != self.get_shape()):
......@@ -1583,12 +1594,16 @@ class point_space(space):
# If the number of dof is equal or 1, use np.reshape...
temp = x.reshape(self.get_shape())
# ... and cast again
return self.cast(temp, dtype=dtype)
return self._cast_to_np(temp,
dtype=dtype,
**kwargs)
elif x.size == 1:
temp = np.empty(shape=self.get_shape(),
dtype=dtype)
temp[:] = x
return self.cast(temp, dtype=dtype)
return self._cast_to_np(temp,
dtype=dtype,
**kwargs)
else:
raise ValueError(about._errors.cstring(
"ERROR: Data has incompatible shape!"))
......@@ -1596,14 +1611,15 @@ class point_space(space):
# Check the datatype
if x.dtype != dtype:
about.warnings.cflush(
"WARNING: Datatypes are uneqal/of conflicting precision (own: "
+ str(dtype) + " <> foreign: " + str(x.dtype)
+ ") and will be casted! "
+ "Potential loss of precision!\n")
"WARNING: Datatypes are uneqal/of conflicting precision " +
" (own: " + str(dtype) + " <> foreign: " + str(x.dtype) +
") and will be casted! Potential loss of precision!\n")
# Fix the datatype...
temp = x.astype(dtype)
# ... and cast again
return self.cast(temp, dtype=dtype)
return self._cast_to_np(temp,
dtype=dtype,
**kwargs)
return x
......@@ -1613,7 +1629,7 @@ class point_space(space):
temp = np.empty(self.get_shape(), dtype=dtype)
if x is not None:
temp[:] = x
return temp
return self._cast_to_np(temp)
def enforce_shape(self, x):
"""
......
......@@ -2668,33 +2668,6 @@ class _dtype_converter(object):
"""
def __init__(self):
# pre_dict = [
# #[, MPI_CHAR],
# #[, MPI_SIGNED_CHAR],
# #[, MPI_UNSIGNED_CHAR],
# [np.bool_, MPI.BYTE],
# [np.int16, MPI.SHORT],
# [np.uint16, MPI.UNSIGNED_SHORT],
# [np.uint32, MPI.UNSIGNED_INT],
# [np.int32, MPI.INT],
# [np.int, MPI.LONG],
# [np.int_, MPI.LONG],
# [np.int64, MPI.LONG],
# [np.long, MPI.LONG],
# [np.longlong, MPI.LONG_LONG],
# [np.uint64, MPI.UNSIGNED_LONG],
# [np.ulonglong, MPI.UNSIGNED_LONG_LONG],
# [np.int64, MPI.LONG_LONG],
# [np.uint64, MPI.UNSIGNED_LONG_LONG],
# [np.float32, MPI.FLOAT],
# [np.float, MPI.DOUBLE],
# [np.float_, MPI.DOUBLE],
# [np.float64, MPI.DOUBLE],
# [np.float128, MPI.LONG_DOUBLE],
# [np.complex64, MPI.COMPLEX],
# [np.complex, MPI.DOUBLE_COMPLEX],
# [np.complex_, MPI.DOUBLE_COMPLEX],
# [np.complex128, MPI.DOUBLE_COMPLEX]]
pre_dict = [
# [, MPI_CHAR],
# [, MPI_SIGNED_CHAR],
......
......@@ -8,30 +8,17 @@ Created on Thu Apr 2 21:29:30 2015
import numpy as np
from keepers import about
"""
def paradict_getter(space_instance):
paradict_dictionary = {
str(space().__class__) : _space_paradict,
str(point_space((2)).__class__) : _point_space_paradict,
str(rg_space((2)).__class__) : _rg_space_paradict,
str(nested_space([point_space(2), point_space(2)]).__class__) : _nested_space_paradict,
str(lm_space(1).__class__) : _lm_space_paradict,
str(gl_space(2).__class__) : _gl_space_paradict,
str(hp_space(1).__class__) : _hp_space_paradict,
}
return paradict_dictionary[str(space_instance.__class__)]()
"""
class space_paradict(object):
def __init__(self, **kwargs):
self.parameters = {}
for key in kwargs:
self[key] = kwargs[key]
def __eq__(self, other):
return (isinstance(other, self.__class__)
and self.__dict__ == other.__dict__)
return (isinstance(other, self.__class__) and
self.__dict__ == other.__dict__)
def __ne__(self, other):
return not self.__eq__(other)
......@@ -41,303 +28,339 @@ class space_paradict(object):
def __setitem__(self, key, arg):
if(np.isscalar(arg)):
arg = np.array([arg],dtype=np.int)
arg = np.array([arg], dtype=np.int)
else:
arg = np.array(arg,dtype=np.int)
arg = np.array(arg, dtype=np.int)
self.parameters.__setitem__(key, arg)
def __getitem__(self, key):
return self.parameters.__getitem__(key)
class point_space_paradict(space_paradict):
def __setitem__(self, key, arg):
if key is not 'num':
raise ValueError(about._errors.cstring("ERROR: Unsupported point_space parameter"))
temp = np.array(arg, dtype=int).flatten()[0]
raise ValueError(about._errors.cstring(
"ERROR: Unsupported point_space parameter"))
if not np.isscalar(arg):
raise ValueError(about._errors.cstring(
"ERROR: 'num' parameter must be scalar. Got: " + str(arg)))
if abs(arg) != arg:
raise ValueError(about._errors.cstring(
"ERROR: 'num' parameter must be positive. Got: " + str(arg)))
temp = np.int(arg)
self.parameters.__setitem__(key, temp)
class rg_space_paradict(space_paradict):
def __init__(self, num, complexity=2, zerocenter=False):
self.ndim = len(np.array(num).flatten())
space_paradict.__init__(self, num=num, complexity=complexity, zerocenter=zerocenter)
space_paradict.__init__(
self, num=num, complexity=complexity, zerocenter=zerocenter)
def __setitem__(self, key, arg):
if key not in ['num', 'complexity', 'zerocenter']:
raise ValueError(about._errors.cstring("ERROR: Unsupported rg_space parameter"))
raise ValueError(about._errors.cstring(
"ERROR: Unsupported rg_space parameter"))
if key == 'num':
temp = list(np.array(arg, dtype=int).flatten())
if len(temp) != self.ndim:
raise ValueError(about._errors.cstring("ERROR: Number of dimensions does not match the init value."))
raise ValueError(about._errors.cstring(
"ERROR: Number of dimensions does not match the init " +
"value."))
elif key == 'complexity':
temp = int(arg)
elif key == 'zerocenter':
temp = np.empty(self.ndim, dtype=bool)
temp[:] = arg
temp = list(temp)
#if len(temp) != self.ndim:
# raise ValueError(about._errors.cstring("ERROR: Number of dimensions does not match the init value."))
self.parameters.__setitem__(key, temp)
class nested_space_paradict(space_paradict):
def __init__(self, ndim):
self.ndim = np.int(ndim)
space_paradict.__init__(self)
def __setitem__(self, key, arg):
if not isinstance(key, int):
raise ValueError(about._errors.cstring("ERROR: Unsupported point_space parameter"))
raise ValueError(about._errors.cstring(
"ERROR: Unsupported point_space parameter"))
if key >= self.ndim or key < 0:
raise ValueError(about._errors.cstring("ERROR: Nestindex out of bounds"))
raise ValueError(about._errors.cstring(
"ERROR: Nestindex out of bounds"))
temp = list(np.array(arg, dtype=int).flatten())
self.parameters.__setitem__(key, temp)
class lm_space_paradict(space_paradict):
def __init__(self, lmax, mmax=None):
space_paradict.__init__(self, lmax=lmax)
if mmax == None:
mmax = -1
if mmax is None:
mmax = -1
self['mmax'] = mmax
def __setitem__(self, key, arg):
if key not in ['lmax', 'mmax']:
raise ValueError(about._errors.cstring("ERROR: Unsupported rg_space parameter"))
raise ValueError(about._errors.cstring(
"ERROR: Unsupported rg_space parameter"))
if key == 'lmax':
temp = int(arg)
if(temp<1):
raise ValueError(about._errors.cstring("ERROR: lmax: nonpositive number."))
if (temp%2 == 0) and (temp > 2): ## exception lmax == 2 (nside == 1)
about.warnings.cprint("WARNING: unrecommended parameter (lmax <> 2*n+1).")
if(temp < 1):
raise ValueError(about._errors.cstring(
"ERROR: lmax: nonpositive number."))
# exception lmax == 2 (nside == 1)
if (temp % 2 == 0) and (temp > 2):
about.warnings.cprint(
"WARNING: unrecommended parameter (lmax <> 2*n+1).")
try:
if temp < self['mmax']:
about.warnings.cprint("WARNING: mmax parameter set to lmax.")
about.warnings.cprint(
"WARNING: mmax parameter set to lmax.")
self['mmax'] = temp
if (temp != self['mmax']):
about.warnings.cprint("WARNING: unrecommended parameter set (mmax <> lmax).")
about.warnings.cprint(
"WARNING: unrecommended parameter set (mmax <> lmax).")
except:
pass
elif key == 'mmax':
temp = int(arg)
if (temp < 1) or(temp > self['lmax']):
about.warnings.cprint("WARNING: mmax parameter set to default.")
about.warnings.cprint(
"WARNING: mmax parameter set to default.")
temp = self['lmax']
if(temp != self['lmax']):
about.warnings.cprint("WARNING: unrecommended parameter set (mmax <> lmax).")
about.warnings.cprint(
"WARNING: unrecommended parameter set (mmax <> lmax).")
self.parameters.__setitem__(key, temp)
class gl_space_paradict(space_paradict):
def __init__(self, nlat, nlon=None):
space_paradict.__init__(self, nlat=nlat)
if nlon == None:
nlon = -1
if nlon is None:
nlon = -1
self['nlon'] = nlon
def __setitem__(self, key, arg):
if key not in ['nlat', 'nlon']:
raise ValueError(about._errors.cstring("ERROR: Unsupported rg_space parameter"))
raise ValueError(about._errors.cstring(
"ERROR: Unsupported rg_space parameter"))
if key == 'nlat':
temp = int(arg)
if(temp<1):
raise ValueError(about._errors.cstring("ERROR: nlat: nonpositive number."))
if (temp%2 != 0):
raise ValueError(about._errors.cstring("ERROR: invalid parameter (nlat <> 2n)."))
if(temp < 1):
raise ValueError(about._errors.cstring(
"ERROR: nlat: nonpositive number."))
if (temp % 2 != 0):
raise ValueError(about._errors.cstring(
"ERROR: invalid parameter (nlat <> 2n)."))
try:
if temp < self['mmax']:
about.warnings.cprint("WARNING: mmax parameter set to lmax.")
about.warnings.cprint(
"WARNING: mmax parameter set to lmax.")
self['mmax'] = temp
if (temp != self['mmax']):
about.warnings.cprint("WARNING: unrecommended parameter set (mmax <> lmax).")
about.warnings.cprint(
"WARNING: unrecommended parameter set (mmax <> lmax).")
except:
pass
elif key == 'nlon':
temp = int(arg)
if (temp < 1):
about.warnings.</