Commit 25f052c3 authored by csongor's avatar csongor
Browse files

WIP: Field support for multiple spaces.

parent e735aeae
...@@ -911,9 +911,11 @@ class point_space(space): ...@@ -911,9 +911,11 @@ class point_space(space):
'mean': lambda y: getattr(y, 'mean')(axis=axis), 'mean': lambda y: getattr(y, 'mean')(axis=axis),
'std': lambda y: getattr(y, 'std')(axis=axis), 'std': lambda y: getattr(y, 'std')(axis=axis),
'var': lambda y: getattr(y, 'var')(axis=axis), 'var': lambda y: getattr(y, 'var')(axis=axis),
'argmin_nonflat': lambda y: getattr(y, 'argmin_nonflat')(axis=axis), 'argmin_nonflat': lambda y: getattr(y, 'argmin_nonflat')(
axis=axis),
'argmin': lambda y: getattr(y, 'argmin')(axis=axis), 'argmin': lambda y: getattr(y, 'argmin')(axis=axis),
'argmax_nonflat': lambda y: getattr(y, 'argmax_nonflat')(axis=axis), 'argmax_nonflat': lambda y: getattr(y, 'argmax_nonflat')(
axis=axis),
'argmax': lambda y: getattr(y, 'argmax')(axis=axis), 'argmax': lambda y: getattr(y, 'argmax')(axis=axis),
'conjugate': lambda y: getattr(y, 'conjugate')(), 'conjugate': lambda y: getattr(y, 'conjugate')(),
'sum': lambda y: getattr(y, 'sum')(axis=axis), 'sum': lambda y: getattr(y, 'sum')(axis=axis),
...@@ -1038,25 +1040,7 @@ class point_space(space): ...@@ -1038,25 +1040,7 @@ class point_space(space):
return self.calc_weight(mol, power=1) return self.calc_weight(mol, power=1)
def cast(self, x=None, dtype=None, **kwargs): def cast(self, x=None, dtype=None, **kwargs):
if dtype is not None: return self._cast_to_d2o(x=x, dtype=dtype, **kwargs)
dtype = np.dtype(dtype)
# If x is a field, extract the data and do a recursive call
if isinstance(x, field):
# Check if the domain matches
if self != x.domain:
about.warnings.cflush(
"WARNING: Getting data from foreign domain!")
# Extract the data, whatever it is, and cast it again
return self.cast(x.val,
dtype=dtype,
**kwargs)
else:
return self._cast_to_d2o(x=x,
dtype=dtype,
**kwargs)
def _cast_to_d2o(self, x, dtype=None, **kwargs): def _cast_to_d2o(self, x, dtype=None, **kwargs):
""" """
...@@ -1081,6 +1065,8 @@ class point_space(space): ...@@ -1081,6 +1065,8 @@ class point_space(space):
Whether the method should raise a warning if information is Whether the method should raise a warning if information is
lost during casting (default: False). lost during casting (default: False).
""" """
if dtype is not None:
dtype = np.dtype(dtype)
if dtype is None: if dtype is None:
dtype = self.dtype dtype = self.dtype
...@@ -1357,8 +1343,8 @@ class point_space(space): ...@@ -1357,8 +1343,8 @@ class point_space(space):
processed_std = std processed_std = std
else: else:
try: try:
processed_std = sample.distributor.\ processed_std = sample.distributor. \
extract_local_data(std) extract_local_data(std)
except(AttributeError): except(AttributeError):
processed_std = std processed_std = std
...@@ -1375,8 +1361,6 @@ class point_space(space): ...@@ -1375,8 +1361,6 @@ class point_space(space):
vmax=arg['vmax'])) vmax=arg['vmax']))
return sample return sample
def calc_weight(self, x, power=1): def calc_weight(self, x, power=1):
""" """
Weights a given array of field values with the pixel volumes (not Weights a given array of field values with the pixel volumes (not
...@@ -1575,7 +1559,7 @@ class point_space(space): ...@@ -1575,7 +1559,7 @@ class point_space(space):
ax0 = fig.add_axes([0.12, 0.12, 0.82, 0.76]) ax0 = fig.add_axes([0.12, 0.12, 0.82, 0.76])
xaxes = np.arange(self.para[0], dtype=np.dtype('int')) xaxes = np.arange(self.para[0], dtype=np.dtype('int'))
if(norm == "log")and(vmin <= 0): if (norm == "log") and (vmin <= 0):
raise ValueError(about._errors.cstring( raise ValueError(about._errors.cstring(
"ERROR: nonpositive value(s).")) "ERROR: nonpositive value(s)."))
...@@ -1741,8 +1725,9 @@ class field(object): ...@@ -1741,8 +1725,9 @@ class field(object):
""" """
def __init__(self, domain=None, val=None, codomain=None, ishape=None, def __init__(self, domain=None, val=None, codomain=None,
copy=False, **kwargs): copy=False, dtype=np.dtype('float64'), datamodel='not',
**kwargs):
""" """
Sets the attributes for a field class instance. Sets the attributes for a field class instance.
...@@ -1771,32 +1756,31 @@ class field(object): ...@@ -1771,32 +1756,31 @@ class field(object):
self._init_from_field(f=val, self._init_from_field(f=val,
domain=domain, domain=domain,
codomain=codomain, codomain=codomain,
ishape=ishape,
copy=copy, copy=copy,
dtype=dtype,
datamodel=datamodel,
**kwargs) **kwargs)
else: else:
self._init_from_array(val=val, self._init_from_array(val=val,
domain=domain, domain=domain,
codomain=codomain, codomain=codomain,
ishape=ishape,
copy=copy, copy=copy,
dtype=dtype,
datamodel=datamodel,
**kwargs) **kwargs)
def _init_from_field(self, f, domain, codomain, ishape, copy, **kwargs): def _init_from_field(self, f, domain, codomain, copy, dtype, datamodel,
**kwargs):
# check domain # check domain
if domain is None: if domain is None:
domain = f.domain domain = f.domain
# check codomain # check codomain
if codomain is None: if codomain is None:
if domain.check_codomain(f.codomain): if self.check_codomain(domain, f.codomain):
codomain = f.codomain codomain = f.codomain
else: else:
codomain = domain.get_codomain() codomain = self.get_codomain(domain)
# check for ishape
if ishape is None:
ishape = f.ishape
# Check if the given field lives in a space which is compatible to the # Check if the given field lives in a space which is compatible to the
# given domain # given domain
...@@ -1808,51 +1792,78 @@ class field(object): ...@@ -1808,51 +1792,78 @@ class field(object):
self._init_from_array(domain=domain, self._init_from_array(domain=domain,
val=f.val, val=f.val,
codomain=codomain, codomain=codomain,
ishape=ishape,
copy=copy, copy=copy,
dtype=dtype,
datamodel=datamodel,
**kwargs) **kwargs)
def _init_from_array(self, val, domain, codomain, ishape, copy, **kwargs): def _init_from_array(self, val, domain, codomain, copy, dtype, datamodel,
**kwargs):
if dtype is None:
dtype = np.dtype('float64')
self.dtype = dtype
if datamodel not in DISTRIBUTION_STRATEGIES['global']:
about.warnings.cprint("WARNING: datamodel set to default.")
self.datamodel = \
gc['default_distribution_strategy']
else:
self.datamodel = datamodel
# check domain # check domain
if not isinstance(domain, space): self.domain = self.check_valid_domain(domain=domain)
raise TypeError(about._errors.cstring(
"ERROR: Given domain is not a space."))
self.domain = domain
# check codomain # check codomain
if codomain is None: if codomain is None:
codomain = domain.get_codomain() codomain = self.get_codomain(domain)
elif not self.domain.check_codomain(codomain): elif not self.check_codomain(domain=domain, codomain=codomain):
raise ValueError(about._errors.cstring( raise ValueError(about._errors.cstring(
"ERROR: The given codomain is not compatible to the domain.")) "ERROR: The given codomain is not compatible to the domain."))
self.codomain = codomain self.codomain = codomain
if ishape is not None:
ishape = tuple(np.array(ishape, dtype=np.uint).flatten())
elif val is not None:
try:
if val.dtype.type == np.object_:
ishape = val.shape
else:
ishape = ()
except(AttributeError):
try:
ishape = val.ishape
except(AttributeError):
ishape = ()
else:
ishape = ()
self.ishape = ishape
if val is None: if val is None:
if kwargs == {}: if kwargs == {}:
val = self._map(lambda: self.domain.cast(0.)) val = self._map(lambda: self.cast((0,)))
else: else:
val = self._map(lambda: self.domain.get_random_values( val = self._map(lambda: self.domain.get_random_values(
codomain=self.codomain, codomain=self.codomain,
**kwargs)) **kwargs))
self.set_val(new_val=val, copy=copy) self.set_val(new_val=val, copy=copy)
def check_valid_domain(self, domain):
if not isinstance(domain, np.ndarray):
raise TypeError(about._errors.cstring(
"ERROR: The given domain is not a list."))
for d in domain:
if not isinstance(d, space):
raise TypeError(about._errors.cstring(
"ERROR: Given domain is not a space."))
elif d.dtype != self.dtype:
raise AttributeError(about._errors.cstring(
"ERROR: The dtype of a space in the domain missmatches "
"the field's dtype."))
elif d.datamodel != self.datamodel:
raise AttributeError(about._errors.cstring(
"ERROR: The datamodel of a space in the domain missmatches "
"the field's datamodel."))
return domain
def check_codomain(self, domain, codomain):
if codomain is None:
return False
if domain.shape == codomain.shape:
return np.all(map((lambda d, c: d._check_codomain(c)), domain,
codomain))
else:
return False
def get_codomain(self, domain):
if domain.shape == (1,):
return np.array(domain[0].get_codomain())
else:
# TODO implement for multiple domain get_codomain need
# calc_transform
return np.empty((0,))
def __len__(self): def __len__(self):
return int(self.get_dim(split=True)[0]) return int(self.get_dim(split=True)[0])
...@@ -2010,29 +2021,127 @@ class field(object): ...@@ -2010,29 +2021,127 @@ class field(object):
def get_ishape(self): def get_ishape(self):
return self.ishape return self.ishape
def get_global_shape(self):
global_shape = np.sum([space.get_shape() for space in self.domain])
if isinstance(global_shape, tuple):
return global_shape
else:
return ()
def _map(self, function, *args): def _map(self, function, *args):
return utilities.field_map(self.ishape, function, *args) return utilities.field_map(self.get_global_shape(), function, *args)
def cast(self, x=None, ishape=None): def cast(self, x=None, dtype=None):
if ishape is None: if dtype is not None:
ishape = self.ishape dtype = np.dtype(dtype)
casted_x = self._cast_to_ishape(x, ishape=ishape) if dtype is None:
if ishape == (): dtype = self.dtype
casted_x = self._cast_to_shape(x)
if self.get_global_shape() == ():
return self.domain.cast(casted_x) return self.domain.cast(casted_x)
else: else:
return self._map(lambda z: self.domain.cast(z), return self._map(lambda z: self.domain.cast(z),
casted_x) casted_x)
def _cast_to_ishape(self, x, ishape=None): def _cast_to_d2o(self, x, dtype=None, **kwargs):
if ishape is None: """
ishape = self.ishape Computes valid field values from a given object, trying
to translate the given data into a valid form. Thereby it is as
benevolent as possible.
Parameters
----------
x : {float, numpy.ndarray, nifty.field}
Object to be transformed into an array of valid field values.
Returns
-------
x : numpy.ndarray, distributed_data_object
Array containing the field values, which are compatible to the
space.
Other parameters
----------------
verbose : bool, *optional*
Whether the method should raise a warning if information is
lost during casting (default: False).
"""
if dtype is None:
dtype = self.dtype
# Case 1: x is a distributed_data_object
if isinstance(x, distributed_data_object):
to_copy = False
# Check the shape
if np.any(np.array(x.shape) != np.array(self.get_shape())):
# Check if at least the number of degrees of freedom is equal
if x.get_dim() == self.get_dim():
try:
temp = x.copy_empty(global_shape=self.get_shape())
temp.set_local_data(x.get_local_data(), copy=False)
except:
# If the number of dof is equal or 1, use np.reshape...
about.warnings.cflush(
"WARNING: Trying to reshape the data. This " +
"operation is expensive as it consolidates the " +
"full data!\n")
temp = x.get_full_data()
temp = np.reshape(temp, self.get_shape())
# ... and cast again
return self._cast_to_d2o(temp,
dtype=dtype,
**kwargs)
else:
raise ValueError(about._errors.cstring(
"ERROR: Data has incompatible shape!"))
# Check the dtype
if x.dtype != dtype:
if x.dtype > dtype:
about.warnings.cflush(
"WARNING: Datatypes are of conflicting precision " +
"(own: " + str(dtype) + " <> foreign: " +
str(x.dtype) + ") and will be casted! Potential " +
"loss of precision!\n")
to_copy = True
# Check the distribution_strategy
if x.distribution_strategy != self.datamodel:
to_copy = True
if to_copy:
temp = x.copy_empty(dtype=dtype,
distribution_strategy=self.datamodel)
temp.set_data(to_key=(slice(None),),
data=x,
from_key=(slice(None),))
temp.hermitian = x.hermitian
x = temp
return x
# Case 2: x is something else
# Use general d2o casting
else:
x = distributed_data_object(x,
global_shape=self.get_shape(),
dtype=dtype,
distribution_strategy=self.datamodel)
# Cast the d2o
return self.cast(x, dtype=dtype)
def _cast_to_shape(self, x):
if isinstance(x, field): if isinstance(x, field):
x = x.get_val() x = x.get_val()
if ishape == ():
global_shape = self.get_global_shape()
if global_shape == ():
casted_x = self._cast_to_scalar_helper(x) casted_x = self._cast_to_scalar_helper(x)
else: else:
casted_x = self._cast_to_tensor_helper(x, ishape) casted_x = self._cast_to_tensor_helper(x, shape=global_shape)
return casted_x return casted_x
def _cast_to_scalar_helper(self, x): def _cast_to_scalar_helper(self, x):
...@@ -2068,26 +2177,26 @@ class field(object): ...@@ -2068,26 +2177,26 @@ class field(object):
# In all other cases, cast x directly # In all other cases, cast x directly
return x return x
def _cast_to_tensor_helper(self, x, ishape=None): def _cast_to_tensor_helper(self, x, shape=None):
if ishape is None: if shape is None:
ishape = self.ishape shape = self.get_global_shape()
# Check if x is a container of proper length # Check if x is a container of proper length
# containing something which will then checked by the domain-space # containing something which will then checked by the domain-space
x_shape = np.shape(x) x_shape = np.shape(x)
self_shape = self.domain.get_shape() self_shape = self.get_global_shape()
try: try:
container_Q = (x.dtype.type == np.object_) container_Q = (x.dtype.type == np.object_)
except(AttributeError): except(AttributeError):
container_Q = False container_Q = False
if container_Q: if container_Q:
if x_shape == ishape: if x_shape == shape:
return x return x
elif x_shape == ishape[:len(x_shape)]: elif x_shape == shape[:len(x_shape)]:
return x.reshape(x_shape + return x.reshape(x_shape +
(1,) * (len(ishape) - len(x_shape))) (1,) * (len(shape) - len(x_shape)))
# Slow track: x could be a pure ndarray # Slow track: x could be a pure ndarray
...@@ -2095,11 +2204,11 @@ class field(object): ...@@ -2095,11 +2204,11 @@ class field(object):
# 1: There are cases where np.shape will only find the container # 1: There are cases where np.shape will only find the container
# although it was no np.object array; e.g. for [a,1]. # although it was no np.object array; e.g. for [a,1].
# 2: The overall shape is already the right one # 2: The overall shape is already the right one
if x_shape == ishape or x_shape == (ishape + self_shape): if x_shape == shape or x_shape == (shape + self_shape):
# Iterate over the outermost dimension and cast the inner spaces # Iterate over the outermost dimension and cast the inner spaces
result = np.empty(ishape, dtype=np.object) result = np.empty(shape, dtype=np.object)
for i in xrange(np.prod(ishape)): for i in xrange(np.prod(shape)):
ii = np.unravel_index(i, ishape) ii = np.unravel_index(i, shape)
try: try:
result[ii] = x[ii] result[ii] = x[ii]
except(TypeError): except(TypeError):
...@@ -2112,16 +2221,16 @@ class field(object): ...@@ -2112,16 +2221,16 @@ class field(object):
# Check if the input has shape (1, self.domain.shape) # Check if the input has shape (1, self.domain.shape)
# Iterate over the outermost dimension and cast the inner spaces # Iterate over the outermost dimension and cast the inner spaces
elif x_shape == ((1,) + self_shape): elif x_shape == ((1,) + self_shape):
result = np.empty(ishape, dtype=np.object) result = np.empty(shape, dtype=np.object)
for i in xrange(np.prod(ishape)): for i in xrange(np.prod(shape)):
ii = np.unravel_index(i, ishape) ii = np.unravel_index(i, shape)
result[ii] = x[0] result[ii] = x[0]
# Case 4: fallback: try to cast x with self.domain # Case 4: fallback: try to cast x with self.domain
else: # Iterate over the outermost dimension and cast the inner spaces else: # Iterate over the outermost dimension and cast the inner spaces
result = np.empty(ishape, dtype=np.object) result = np.empty(shape, dtype=np.object)
for i in xrange(np.prod(ishape)): for i in xrange(np.prod(shape)):
ii = np.unravel_index(i, ishape) ii = np.unravel_index(i, shape)
result[ii] = x result[ii] = x
return result return result
...@@ -2903,4 +3012,4 @@ class field(object): ...@@ -2903,4 +3012,4 @@ class field(object):
class EmptyField(field): class EmptyField(field):
def __init__(self): def __init__(self):
pass pass
\ No newline at end of file
...@@ -56,11 +56,11 @@ all_hp_datatypes = [np.dtype('float64')] ...@@ -56,11 +56,11 @@ all_hp_datatypes = [np.dtype('float64')]
############################################################################### ###############################################################################
DATAMODELS = {} DATAMODELS = {}
DATAMODELS['point_space'] = ['np'] + POINT_DISTRIBUTION_STRATEGIES DATAMODELS['point_space'] = POINT_DISTRIBUTION_STRATEGIES
DATAMODELS['rg_space'] = ['np'] + RG_DISTRIBUTION_STRATEGIES DATAMODELS['rg_space'] = RG_DISTRIBUTION_STRATEGIES
DATAMODELS['lm_space'] = ['np'] + LM_DISTRIBUTION_STRATEGIES DATAMODELS['lm_space'] = LM_DISTRIBUTION_STRATEGIES
DATAMODELS['gl_space'] = ['np'] + GL_DISTRIBUTION_STRATEGIES DATAMODELS['gl_space'] = GL_DISTRIBUTION_STRATEGIES
DATAMODELS['hp_space'] = ['np'] + HP_DISTRIBUTION_STRATEGIES DATAMODELS['hp_space'] = HP_DISTRIBUTION_STRATEGIES
############################################################################### ###############################################################################
...@@ -110,10 +110,9 @@ class Test_field_init(unittest.TestCase): ...@@ -110,10 +110,9 @@ class Test_field_init(unittest.TestCase):
@parameterized.expand(space_list) @parameterized.expand(space_list)
def test_successfull_init_and_attributes(self, s): def test_successfull_init_and_attributes(self, s):
s = s[0] f = field(domain=np.array([s]), dtype=s.dtype)
f = field(s) assert(f.domain[0] is s)
assert(f.domain is s) assert(s.check_codomain(f.codomain[0]))
assert(s.check_codomain(f.codomain))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment