Commit 6c50de1e authored by Ultima's avatar Ultima
Browse files

Added first tests for point_space.

parent a041bd3f
...@@ -2,15 +2,18 @@ ...@@ -2,15 +2,18 @@
import numpy as np import numpy as np
def MIN(): #def MIN():
return np.min # return np.min
#
def MAX(): #def MAX():
return np.max # return np.max
#
def SUM(): #def SUM():
return np.sum # return np.sum
MIN = np.min
MAX = np.max
SUM = np.sum
class _COMM_WORLD(): class _COMM_WORLD():
...@@ -27,7 +30,7 @@ class _COMM_WORLD(): ...@@ -27,7 +30,7 @@ class _COMM_WORLD():
def _scattergather_helper(self, sendbuf, recvbuf=None, **kwargs): def _scattergather_helper(self, sendbuf, recvbuf=None, **kwargs):
sendbuf = self._unwrapper(sendbuf) sendbuf = self._unwrapper(sendbuf)
recvbuf = self._unwrapper(recvbuf) recvbuf = self._unwrapper(recvbuf)
if recvbuf != None: if recvbuf is not None:
recvbuf[:] = sendbuf recvbuf[:] = sendbuf
return recvbuf return recvbuf
else: else:
...@@ -50,7 +53,7 @@ class _COMM_WORLD(): ...@@ -50,7 +53,7 @@ class _COMM_WORLD():
return self._scattergather_helper(*args, **kwargs) return self._scattergather_helper(*args, **kwargs)
def gather(self, sendbuf, *args, **kwargs): def gather(self, sendbuf, *args, **kwargs):
return [sendbuf,] return [sendbuf]
def Gather(self, *args, **kwargs): def Gather(self, *args, **kwargs):
return self._scattergather_helper(*args, **kwargs) return self._scattergather_helper(*args, **kwargs)
...@@ -59,7 +62,7 @@ class _COMM_WORLD(): ...@@ -59,7 +62,7 @@ class _COMM_WORLD():
return self._scattergather_helper(*args, **kwargs) return self._scattergather_helper(*args, **kwargs)
def allgather(self, sendbuf, *args, **kwargs): def allgather(self, sendbuf, *args, **kwargs):
return [sendbuf,] return [sendbuf]
def Allgather(self, *args, **kwargs): def Allgather(self, *args, **kwargs):
return self._scattergather_helper(*args, **kwargs) return self._scattergather_helper(*args, **kwargs)
...@@ -87,6 +90,7 @@ class _COMM_WORLD(): ...@@ -87,6 +90,7 @@ class _COMM_WORLD():
def Barrier(self): def Barrier(self):
pass pass
class _datatype(): class _datatype():
def __init__(self, name): def __init__(self, name):
self.name = str(name) self.name = str(name)
...@@ -108,14 +112,3 @@ DOUBLE_COMPLEX = _datatype("MPI_DOUBLE_COMPLEX") ...@@ -108,14 +112,3 @@ DOUBLE_COMPLEX = _datatype("MPI_DOUBLE_COMPLEX")
COMM_WORLD = _COMM_WORLD() COMM_WORLD = _COMM_WORLD()
...@@ -103,6 +103,9 @@ class configuration(object): ...@@ -103,6 +103,9 @@ class configuration(object):
for key, item in self.variable_dict.items(): for key, item in self.variable_dict.items():
item.set_value(None) item.set_value(None)
def validQ(self, name, value):
return self.variable_dict[name].checker(value)
def save(self, path=None, path_section=None): def save(self, path=None, path_section=None):
if path is None: if path is None:
if self.path is None: if self.path is None:
......
...@@ -20,11 +20,11 @@ ...@@ -20,11 +20,11 @@
## along with this program. If not, see <http://www.gnu.org/licenses/>. ## along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import division from __future__ import division
from nifty.keepers import about from nifty.keepers import about,\
global_dependency_injector as gdi
from distutils.version import LooseVersion as lv from distutils.version import LooseVersion as lv
try: try:
import libsharp_wrapper_gl as gl import libsharp_wrapper_gl as gl
except(ImportError): except(ImportError):
......
...@@ -1087,7 +1087,7 @@ class point_space(space): ...@@ -1087,7 +1087,7 @@ class point_space(space):
Pixel volume of the :py:class:`point_space`, which is always 1. Pixel volume of the :py:class:`point_space`, which is always 1.
""" """
def __init__(self, num, datatype=None, datamodel='fftw'): def __init__(self, num, datatype=np.dtype('float'), datamodel='fftw'):
""" """
Sets the attributes for a point_space class instance. Sets the attributes for a point_space class instance.
...@@ -1104,25 +1104,25 @@ class point_space(space): ...@@ -1104,25 +1104,25 @@ class point_space(space):
""" """
self.paradict = point_space_paradict(num=num) self.paradict = point_space_paradict(num=num)
# check datatype # parse datatype
if (datatype is None): dtype = np.dtype(datatype)
datatype = np.float64 if dtype not in [np.dtype('bool'),
elif (datatype not in [np.bool_, np.dtype('int8'),
np.int8, np.dtype('int16'),
np.int16, np.dtype('int32'),
np.int32, np.dtype('int64'),
np.int64, np.dtype('float16'),
np.float16, np.dtype('float32'),
np.float32, np.dtype('float64'),
np.float64, np.dtype('complex64'),
np.complex64, np.dtype('complex128')]:
np.complex128]): raise ValueError(about._errors.cstring(
about.warnings.cprint("WARNING: data type set to default.") "WARNING: incompatible datatype: " + str(dtype)))
datatype = np.float64 self.dtype = dtype
self.datatype = datatype self.datatype = dtype.type
if datamodel not in ['np'] + POINT_DISTRIBUTION_STRATEGIES: if datamodel not in ['np'] + POINT_DISTRIBUTION_STRATEGIES:
about.warnings.cprint("WARNING: datamodel set to default.") about._errors.cstring("WARNING: datamodel set to default.")
self.datamodel = \ self.datamodel = \
global_configuration['default_distribution_strategy'] global_configuration['default_distribution_strategy']
else: else:
...@@ -1139,7 +1139,7 @@ class point_space(space): ...@@ -1139,7 +1139,7 @@ class point_space(space):
@para.setter @para.setter
def para(self, x): def para(self, x):
self.paradict['num'] = x self.paradict['num'] = x[0]
def copy(self): def copy(self):
return point_space(num=self.paradict['num'], return point_space(num=self.paradict['num'],
...@@ -1442,21 +1442,34 @@ class point_space(space): ...@@ -1442,21 +1442,34 @@ class point_space(space):
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def cast(self, x, dtype=None, verbose=False, **kwargs): def cast(self, x, dtype=None, **kwargs):
if dtype is not None: if dtype is not None:
dtype = np.dtype(dtype).type dtype = np.dtype(dtype).type
# If x is a field, extract the data and do a recursive call
if isinstance(x, field):
# Check if the domain matches
if self != x.domain:
about.warnings.cflush(
"WARNING: Getting data from foreign domain!")
# Extract the data, whatever it is, and cast it again
return self.cast(x.val,
dtype=dtype,
**kwargs)
if self.datamodel in POINT_DISTRIBUTION_STRATEGIES: if self.datamodel in POINT_DISTRIBUTION_STRATEGIES:
return self._cast_to_d2o(x=x, dtype=dtype, verbose=verbose, return self._cast_to_d2o(x=x,
dtype=dtype,
**kwargs) **kwargs)
elif self.datamodel == 'np': elif self.datamodel == 'np':
return self._cast_to_np(x=x, dtype=dtype, verbose=verbose, return self._cast_to_np(x=x,
dtype=dtype,
**kwargs) **kwargs)
else: else:
raise NotImplementedError(about._errors.cstring( raise NotImplementedError(about._errors.cstring(
"ERROR: function is not implemented for given datamodel.")) "ERROR: function is not implemented for given datamodel."))
def _cast_to_d2o(self, x, dtype=None, verbose=False, **kwargs): def _cast_to_d2o(self, x, dtype=None, **kwargs):
""" """
Computes valid field values from a given object, trying Computes valid field values from a given object, trying
to translate the given data into a valid form. Thereby it is as to translate the given data into a valid form. Thereby it is as
...@@ -1482,30 +1495,25 @@ class point_space(space): ...@@ -1482,30 +1495,25 @@ class point_space(space):
if dtype is None: if dtype is None:
dtype = self.datatype dtype = self.datatype
# Case 1: x is a field # Case 1: x is a distributed_data_object
if isinstance(x, field):
if verbose:
# Check if the domain matches
if(self != x.domain):
about.warnings.cflush(
"WARNING: Getting data from foreign domain!")
# Extract the data, whatever it is, and cast it again
return self.cast(x.val, dtype=dtype)
# Case 2: x is a distributed_data_object
if isinstance(x, distributed_data_object): if isinstance(x, distributed_data_object):
to_copy = False
# Check the shape # Check the shape
if np.any(x.shape != self.get_shape()): if np.any(x.shape != self.get_shape()):
# Check if at least the number of degrees of freedom is equal # Check if at least the number of degrees of freedom is equal
if x.get_dim() == self.get_dim(): if x.get_dim() == self.get_dim():
# If the number of dof is equal or 1, use np.reshape... # If the number of dof is equal or 1, use np.reshape...
about.warnings.cflush( about.warnings.cflush(
"WARNING: Trying to reshape the data. This operation is " + "WARNING: Trying to reshape the data. This " +
"expensive as it consolidates the full data!\n") "operation is expensive as it consolidates the " +
"full data!\n")
temp = x.get_full_data() temp = x.get_full_data()
temp = np.reshape(temp, self.get_shape()) temp = np.reshape(temp, self.get_shape())
# ... and cast again # ... and cast again
return self.cast(temp, dtype=dtype) return self._cast_to_d2o(temp,
dtype=dtype,
**kwargs)
else: else:
raise ValueError(about._errors.cstring( raise ValueError(about._errors.cstring(
...@@ -1514,25 +1522,35 @@ class point_space(space): ...@@ -1514,25 +1522,35 @@ class point_space(space):
# Check the datatype # Check the datatype
if x.dtype != dtype: if x.dtype != dtype:
about.warnings.cflush( about.warnings.cflush(
"WARNING: Datatypes are uneqal/of conflicting precision (own: " "WARNING: Datatypes are uneqal/of conflicting precision " +
+ str(dtype) + " <> foreign: " + str(x.dtype) "(own: " + str(dtype) + " <> foreign: " + str(x.dtype) +
+ ") and will be casted! " ") and will be casted! Potential loss of precision!\n")
+ "Potential loss of precision!\n") to_copy = True
temp = x.copy_empty(dtype=dtype)
# Check the distribution_strategy
if x.distribution_strategy != self.datamodel:
to_copy = True
if to_copy:
temp = x.copy_empty(dtype=dtype,
distribution_strategy=self.datamodel)
temp.set_local_data(x.get_local_data()) temp.set_local_data(x.get_local_data())
temp.hermitian = x.hermitian temp.hermitian = x.hermitian
x = temp x = temp
return x return x
# Case 3: x is something else # Case 2: x is something else
# Use general d2o casting # Use general d2o casting
x = distributed_data_object(x, global_shape=self.get_shape(), else:
dtype=dtype) x = distributed_data_object(x,
# Cast the d2o global_shape=self.get_shape(),
return self.cast(x, dtype=dtype) dtype=dtype,
distribution_strategy=self.datamodel)
# Cast the d2o
return self.cast(x, dtype=dtype)
def _cast_to_np(self, x, dtype=None, verbose=False, **kwargs): def _cast_to_np(self, x, dtype=None, **kwargs):
""" """
Computes valid field values from a given object, trying Computes valid field values from a given object, trying
to translate the given data into a valid form. Thereby it is as to translate the given data into a valid form. Thereby it is as
...@@ -1558,23 +1576,16 @@ class point_space(space): ...@@ -1558,23 +1576,16 @@ class point_space(space):
if dtype is None: if dtype is None:
dtype = self.datatype dtype = self.datatype
# Case 1: x is a field # Case 1: x is a distributed_data_object
if isinstance(x, field):
if verbose:
# Check if the domain matches
if(self != x.domain):
about.warnings.cflush(
"WARNING: Getting data from foreign domain!")
# Extract the data, whatever it is, and cast it again
return self.cast(x.val, dtype=dtype)
# Case 2: x is a distributed_data_object
if isinstance(x, distributed_data_object): if isinstance(x, distributed_data_object):
# Extract the data # Extract the data
temp = x.get_full_data() temp = x.get_full_data()
# Cast the resulting numpy array again # Cast the resulting numpy array again
return self.cast(temp, dtype=dtype) return self._cast_to_np(temp,
dtype=dtype,
**kwargs)
# Case 2: x is a distributed_data_object
elif isinstance(x, np.ndarray): elif isinstance(x, np.ndarray):
# Check the shape # Check the shape
if np.any(x.shape != self.get_shape()): if np.any(x.shape != self.get_shape()):
...@@ -1583,12 +1594,16 @@ class point_space(space): ...@@ -1583,12 +1594,16 @@ class point_space(space):
# If the number of dof is equal or 1, use np.reshape... # If the number of dof is equal or 1, use np.reshape...
temp = x.reshape(self.get_shape()) temp = x.reshape(self.get_shape())
# ... and cast again # ... and cast again
return self.cast(temp, dtype=dtype) return self._cast_to_np(temp,
dtype=dtype,
**kwargs)
elif x.size == 1: elif x.size == 1:
temp = np.empty(shape=self.get_shape(), temp = np.empty(shape=self.get_shape(),
dtype=dtype) dtype=dtype)
temp[:] = x temp[:] = x
return self.cast(temp, dtype=dtype) return self._cast_to_np(temp,
dtype=dtype,
**kwargs)
else: else:
raise ValueError(about._errors.cstring( raise ValueError(about._errors.cstring(
"ERROR: Data has incompatible shape!")) "ERROR: Data has incompatible shape!"))
...@@ -1596,14 +1611,15 @@ class point_space(space): ...@@ -1596,14 +1611,15 @@ class point_space(space):
# Check the datatype # Check the datatype
if x.dtype != dtype: if x.dtype != dtype:
about.warnings.cflush( about.warnings.cflush(
"WARNING: Datatypes are uneqal/of conflicting precision (own: " "WARNING: Datatypes are uneqal/of conflicting precision " +
+ str(dtype) + " <> foreign: " + str(x.dtype) " (own: " + str(dtype) + " <> foreign: " + str(x.dtype) +
+ ") and will be casted! " ") and will be casted! Potential loss of precision!\n")
+ "Potential loss of precision!\n")
# Fix the datatype... # Fix the datatype...
temp = x.astype(dtype) temp = x.astype(dtype)
# ... and cast again # ... and cast again
return self.cast(temp, dtype=dtype) return self._cast_to_np(temp,
dtype=dtype,
**kwargs)
return x return x
...@@ -1613,7 +1629,7 @@ class point_space(space): ...@@ -1613,7 +1629,7 @@ class point_space(space):
temp = np.empty(self.get_shape(), dtype=dtype) temp = np.empty(self.get_shape(), dtype=dtype)
if x is not None: if x is not None:
temp[:] = x temp[:] = x
return temp return self._cast_to_np(temp)
def enforce_shape(self, x): def enforce_shape(self, x):
""" """
......
...@@ -2668,33 +2668,6 @@ class _dtype_converter(object): ...@@ -2668,33 +2668,6 @@ class _dtype_converter(object):
""" """
def __init__(self): def __init__(self):
# pre_dict = [
# #[, MPI_CHAR],
# #[, MPI_SIGNED_CHAR],
# #[, MPI_UNSIGNED_CHAR],
# [np.bool_, MPI.BYTE],
# [np.int16, MPI.SHORT],
# [np.uint16, MPI.UNSIGNED_SHORT],
# [np.uint32, MPI.UNSIGNED_INT],
# [np.int32, MPI.INT],
# [np.int, MPI.LONG],
# [np.int_, MPI.LONG],
# [np.int64, MPI.LONG],
# [np.long, MPI.LONG],
# [np.longlong, MPI.LONG_LONG],
# [np.uint64, MPI.UNSIGNED_LONG],
# [np.ulonglong, MPI.UNSIGNED_LONG_LONG],
# [np.int64, MPI.LONG_LONG],
# [np.uint64, MPI.UNSIGNED_LONG_LONG],
# [np.float32, MPI.FLOAT],
# [np.float, MPI.DOUBLE],
# [np.float_, MPI.DOUBLE],
# [np.float64, MPI.DOUBLE],
# [np.float128, MPI.LONG_DOUBLE],
# [np.complex64, MPI.COMPLEX],
# [np.complex, MPI.DOUBLE_COMPLEX],
# [np.complex_, MPI.DOUBLE_COMPLEX],
# [np.complex128, MPI.DOUBLE_COMPLEX]]
pre_dict = [ pre_dict = [
# [, MPI_CHAR], # [, MPI_CHAR],
# [, MPI_SIGNED_CHAR], # [, MPI_SIGNED_CHAR],
......
...@@ -8,30 +8,17 @@ Created on Thu Apr 2 21:29:30 2015 ...@@ -8,30 +8,17 @@ Created on Thu Apr 2 21:29:30 2015
import numpy as np import numpy as np
from keepers import about from keepers import about
"""
def paradict_getter(space_instance):
paradict_dictionary = {
str(space().__class__) : _space_paradict,
str(point_space((2)).__class__) : _point_space_paradict,
str(rg_space((2)).__class__) : _rg_space_paradict,
str(nested_space([point_space(2), point_space(2)]).__class__) : _nested_space_paradict,
str(lm_space(1).__class__) : _lm_space_paradict,
str(gl_space(2).__class__) : _gl_space_paradict,
str(hp_space(1).__class__) : _hp_space_paradict,
}
return paradict_dictionary[str(space_instance.__class__)]()
"""
class space_paradict(object): class space_paradict(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.parameters = {} self.parameters = {}
for key in kwargs: for key in kwargs:
self[key] = kwargs[key] self[key] = kwargs[key]
def __eq__(self, other): def __eq__(self, other):
return (isinstance(other, self.__class__) return (isinstance(other, self.__class__) and
and self.__dict__ == other.__dict__) self.__dict__ == other.__dict__)
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
...@@ -41,303 +28,339 @@ class space_paradict(object): ...@@ -41,303 +28,339 @@ class space_paradict(object):
def __setitem__(self, key, arg): def __setitem__(self, key, arg):
if(np.isscalar(arg)): if(np.isscalar(arg)):
arg = np.array([arg],dtype=np.int) arg = np.array([arg], dtype=np.int)
else: else:
arg = np.array(arg,dtype=np.int) arg = np.array(arg, dtype=np.int)
self.parameters.__setitem__(key, arg) self.parameters.__setitem__(key, arg)
def __getitem__(self, key): def __getitem__(self, key):
return self.parameters.__getitem__(key) return self.parameters.__getitem__(key)
class point_space_paradict(space_paradict): class point_space_paradict(space_paradict):
def __setitem__(self, key, arg): def __setitem__(self, key, arg):
if key is not 'num': if key is not 'num':
raise ValueError(about._errors.cstring("ERROR: Unsupported point_space parameter")) raise ValueError(about._errors.cstring(
temp = np.array(arg, dtype=int).flatten()[0] "ERROR: Unsupported point_space parameter"))
if not np.isscalar(arg):
raise ValueError(about._errors.cstring(
"ERROR: 'num' parameter must be scalar. Got: " + str(arg)))
if abs(arg) != arg:
raise ValueError(about._errors.cstring(
"ERROR: 'num' parameter must be positive. Got: " + str(arg)))
temp = np.int(arg)
self.parameters.__setitem__(key, temp) self.parameters.__setitem__(key, temp)
class rg_space_paradict(space_paradict): class rg_space_paradict(space_paradict):
def __init__(self, num, complexity=2, zerocenter=False): def __init__(self, num, complexity=2, zerocenter=False):
self.ndim = len(np.array(num).flatten()) self.ndim = len(np.array(num).flatten())
space_paradict.__init__(self, num=num, complexity=complexity, zerocenter=zerocenter) space_paradict.__init__(