Commit 25f052c3 authored by csongor's avatar csongor
Browse files

WIP: Field support for multiple spaces.

parent e735aeae
......@@ -911,9 +911,11 @@ class point_space(space):
'mean': lambda y: getattr(y, 'mean')(axis=axis),
'std': lambda y: getattr(y, 'std')(axis=axis),
'var': lambda y: getattr(y, 'var')(axis=axis),
'argmin_nonflat': lambda y: getattr(y, 'argmin_nonflat')(axis=axis),
'argmin_nonflat': lambda y: getattr(y, 'argmin_nonflat')(
axis=axis),
'argmin': lambda y: getattr(y, 'argmin')(axis=axis),
'argmax_nonflat': lambda y: getattr(y, 'argmax_nonflat')(axis=axis),
'argmax_nonflat': lambda y: getattr(y, 'argmax_nonflat')(
axis=axis),
'argmax': lambda y: getattr(y, 'argmax')(axis=axis),
'conjugate': lambda y: getattr(y, 'conjugate')(),
'sum': lambda y: getattr(y, 'sum')(axis=axis),
......@@ -1038,25 +1040,7 @@ class point_space(space):
return self.calc_weight(mol, power=1)
def cast(self, x=None, dtype=None, **kwargs):
if dtype is not None:
dtype = np.dtype(dtype)
# If x is a field, extract the data and do a recursive call
if isinstance(x, field):
# Check if the domain matches
if self != x.domain:
about.warnings.cflush(
"WARNING: Getting data from foreign domain!")
# Extract the data, whatever it is, and cast it again
return self.cast(x.val,
dtype=dtype,
**kwargs)
else:
return self._cast_to_d2o(x=x,
dtype=dtype,
**kwargs)
return self._cast_to_d2o(x=x, dtype=dtype, **kwargs)
def _cast_to_d2o(self, x, dtype=None, **kwargs):
"""
......@@ -1081,6 +1065,8 @@ class point_space(space):
Whether the method should raise a warning if information is
lost during casting (default: False).
"""
if dtype is not None:
dtype = np.dtype(dtype)
if dtype is None:
dtype = self.dtype
......@@ -1357,8 +1343,8 @@ class point_space(space):
processed_std = std
else:
try:
processed_std = sample.distributor.\
extract_local_data(std)
processed_std = sample.distributor. \
extract_local_data(std)
except(AttributeError):
processed_std = std
......@@ -1375,8 +1361,6 @@ class point_space(space):
vmax=arg['vmax']))
return sample
def calc_weight(self, x, power=1):
"""
Weights a given array of field values with the pixel volumes (not
......@@ -1575,7 +1559,7 @@ class point_space(space):
ax0 = fig.add_axes([0.12, 0.12, 0.82, 0.76])
xaxes = np.arange(self.para[0], dtype=np.dtype('int'))
if(norm == "log")and(vmin <= 0):
if (norm == "log") and (vmin <= 0):
raise ValueError(about._errors.cstring(
"ERROR: nonpositive value(s)."))
......@@ -1741,8 +1725,9 @@ class field(object):
"""
def __init__(self, domain=None, val=None, codomain=None, ishape=None,
copy=False, **kwargs):
def __init__(self, domain=None, val=None, codomain=None,
copy=False, dtype=np.dtype('float64'), datamodel='not',
**kwargs):
"""
Sets the attributes for a field class instance.
......@@ -1771,32 +1756,31 @@ class field(object):
self._init_from_field(f=val,
domain=domain,
codomain=codomain,
ishape=ishape,
copy=copy,
dtype=dtype,
datamodel=datamodel,
**kwargs)
else:
self._init_from_array(val=val,
domain=domain,
codomain=codomain,
ishape=ishape,
copy=copy,
dtype=dtype,
datamodel=datamodel,
**kwargs)
def _init_from_field(self, f, domain, codomain, ishape, copy, **kwargs):
def _init_from_field(self, f, domain, codomain, copy, dtype, datamodel,
**kwargs):
# check domain
if domain is None:
domain = f.domain
# check codomain
if codomain is None:
if domain.check_codomain(f.codomain):
if self.check_codomain(domain, f.codomain):
codomain = f.codomain
else:
codomain = domain.get_codomain()
# check for ishape
if ishape is None:
ishape = f.ishape
codomain = self.get_codomain(domain)
# Check if the given field lives in a space which is compatible to the
# given domain
......@@ -1808,51 +1792,78 @@ class field(object):
self._init_from_array(domain=domain,
val=f.val,
codomain=codomain,
ishape=ishape,
copy=copy,
dtype=dtype,
datamodel=datamodel,
**kwargs)
def _init_from_array(self, val, domain, codomain, ishape, copy, **kwargs):
def _init_from_array(self, val, domain, codomain, copy, dtype, datamodel,
**kwargs):
if dtype is None:
dtype = np.dtype('float64')
self.dtype = dtype
if datamodel not in DISTRIBUTION_STRATEGIES['global']:
about.warnings.cprint("WARNING: datamodel set to default.")
self.datamodel = \
gc['default_distribution_strategy']
else:
self.datamodel = datamodel
# check domain
if not isinstance(domain, space):
raise TypeError(about._errors.cstring(
"ERROR: Given domain is not a space."))
self.domain = domain
self.domain = self.check_valid_domain(domain=domain)
# check codomain
if codomain is None:
codomain = domain.get_codomain()
elif not self.domain.check_codomain(codomain):
codomain = self.get_codomain(domain)
elif not self.check_codomain(domain=domain, codomain=codomain):
raise ValueError(about._errors.cstring(
"ERROR: The given codomain is not compatible to the domain."))
self.codomain = codomain
if ishape is not None:
ishape = tuple(np.array(ishape, dtype=np.uint).flatten())
elif val is not None:
try:
if val.dtype.type == np.object_:
ishape = val.shape
else:
ishape = ()
except(AttributeError):
try:
ishape = val.ishape
except(AttributeError):
ishape = ()
else:
ishape = ()
self.ishape = ishape
if val is None:
if kwargs == {}:
val = self._map(lambda: self.domain.cast(0.))
val = self._map(lambda: self.cast((0,)))
else:
val = self._map(lambda: self.domain.get_random_values(
codomain=self.codomain,
**kwargs))
self.set_val(new_val=val, copy=copy)
def check_valid_domain(self, domain):
if not isinstance(domain, np.ndarray):
raise TypeError(about._errors.cstring(
"ERROR: The given domain is not a list."))
for d in domain:
if not isinstance(d, space):
raise TypeError(about._errors.cstring(
"ERROR: Given domain is not a space."))
elif d.dtype != self.dtype:
raise AttributeError(about._errors.cstring(
"ERROR: The dtype of a space in the domain missmatches "
"the field's dtype."))
elif d.datamodel != self.datamodel:
raise AttributeError(about._errors.cstring(
"ERROR: The datamodel of a space in the domain missmatches "
"the field's datamodel."))
return domain
def check_codomain(self, domain, codomain):
if codomain is None:
return False
if domain.shape == codomain.shape:
return np.all(map((lambda d, c: d._check_codomain(c)), domain,
codomain))
else:
return False
def get_codomain(self, domain):
if domain.shape == (1,):
return np.array(domain[0].get_codomain())
else:
# TODO implement for multiple domain get_codomain need
# calc_transform
return np.empty((0,))
def __len__(self):
return int(self.get_dim(split=True)[0])
......@@ -2010,29 +2021,127 @@ class field(object):
def get_ishape(self):
return self.ishape
def get_global_shape(self):
global_shape = np.sum([space.get_shape() for space in self.domain])
if isinstance(global_shape, tuple):
return global_shape
else:
return ()
def _map(self, function, *args):
return utilities.field_map(self.ishape, function, *args)
return utilities.field_map(self.get_global_shape(), function, *args)
def cast(self, x=None, ishape=None):
if ishape is None:
ishape = self.ishape
casted_x = self._cast_to_ishape(x, ishape=ishape)
if ishape == ():
def cast(self, x=None, dtype=None):
if dtype is not None:
dtype = np.dtype(dtype)
if dtype is None:
dtype = self.dtype
casted_x = self._cast_to_shape(x)
if self.get_global_shape() == ():
return self.domain.cast(casted_x)
else:
return self._map(lambda z: self.domain.cast(z),
casted_x)
def _cast_to_ishape(self, x, ishape=None):
if ishape is None:
ishape = self.ishape
def _cast_to_d2o(self, x, dtype=None, **kwargs):
"""
Computes valid field values from a given object, trying
to translate the given data into a valid form. Thereby it is as
benevolent as possible.
Parameters
----------
x : {float, numpy.ndarray, nifty.field}
Object to be transformed into an array of valid field values.
Returns
-------
x : numpy.ndarray, distributed_data_object
Array containing the field values, which are compatible to the
space.
Other parameters
----------------
verbose : bool, *optional*
Whether the method should raise a warning if information is
lost during casting (default: False).
"""
if dtype is None:
dtype = self.dtype
# Case 1: x is a distributed_data_object
if isinstance(x, distributed_data_object):
to_copy = False
# Check the shape
if np.any(np.array(x.shape) != np.array(self.get_shape())):
# Check if at least the number of degrees of freedom is equal
if x.get_dim() == self.get_dim():
try:
temp = x.copy_empty(global_shape=self.get_shape())
temp.set_local_data(x.get_local_data(), copy=False)
except:
# If the number of dof is equal or 1, use np.reshape...
about.warnings.cflush(
"WARNING: Trying to reshape the data. This " +
"operation is expensive as it consolidates the " +
"full data!\n")
temp = x.get_full_data()
temp = np.reshape(temp, self.get_shape())
# ... and cast again
return self._cast_to_d2o(temp,
dtype=dtype,
**kwargs)
else:
raise ValueError(about._errors.cstring(
"ERROR: Data has incompatible shape!"))
# Check the dtype
if x.dtype != dtype:
if x.dtype > dtype:
about.warnings.cflush(
"WARNING: Datatypes are of conflicting precision " +
"(own: " + str(dtype) + " <> foreign: " +
str(x.dtype) + ") and will be casted! Potential " +
"loss of precision!\n")
to_copy = True
# Check the distribution_strategy
if x.distribution_strategy != self.datamodel:
to_copy = True
if to_copy:
temp = x.copy_empty(dtype=dtype,
distribution_strategy=self.datamodel)
temp.set_data(to_key=(slice(None),),
data=x,
from_key=(slice(None),))
temp.hermitian = x.hermitian
x = temp
return x
# Case 2: x is something else
# Use general d2o casting
else:
x = distributed_data_object(x,
global_shape=self.get_shape(),
dtype=dtype,
distribution_strategy=self.datamodel)
# Cast the d2o
return self.cast(x, dtype=dtype)
def _cast_to_shape(self, x):
if isinstance(x, field):
x = x.get_val()
if ishape == ():
global_shape = self.get_global_shape()
if global_shape == ():
casted_x = self._cast_to_scalar_helper(x)
else:
casted_x = self._cast_to_tensor_helper(x, ishape)
casted_x = self._cast_to_tensor_helper(x, shape=global_shape)
return casted_x
def _cast_to_scalar_helper(self, x):
......@@ -2068,26 +2177,26 @@ class field(object):
# In all other cases, cast x directly
return x
def _cast_to_tensor_helper(self, x, ishape=None):
if ishape is None:
ishape = self.ishape
def _cast_to_tensor_helper(self, x, shape=None):
if shape is None:
shape = self.get_global_shape()
# Check if x is a container of proper length
# containing something which will then checked by the domain-space
x_shape = np.shape(x)
self_shape = self.domain.get_shape()
self_shape = self.get_global_shape()
try:
container_Q = (x.dtype.type == np.object_)
except(AttributeError):
container_Q = False
if container_Q:
if x_shape == ishape:
if x_shape == shape:
return x
elif x_shape == ishape[:len(x_shape)]:
elif x_shape == shape[:len(x_shape)]:
return x.reshape(x_shape +
(1,) * (len(ishape) - len(x_shape)))
(1,) * (len(shape) - len(x_shape)))
# Slow track: x could be a pure ndarray
......@@ -2095,11 +2204,11 @@ class field(object):
# 1: There are cases where np.shape will only find the container
# although it was no np.object array; e.g. for [a,1].
# 2: The overall shape is already the right one
if x_shape == ishape or x_shape == (ishape + self_shape):
if x_shape == shape or x_shape == (shape + self_shape):
# Iterate over the outermost dimension and cast the inner spaces
result = np.empty(ishape, dtype=np.object)
for i in xrange(np.prod(ishape)):
ii = np.unravel_index(i, ishape)
result = np.empty(shape, dtype=np.object)
for i in xrange(np.prod(shape)):
ii = np.unravel_index(i, shape)
try:
result[ii] = x[ii]
except(TypeError):
......@@ -2112,16 +2221,16 @@ class field(object):
# Check if the input has shape (1, self.domain.shape)
# Iterate over the outermost dimension and cast the inner spaces
elif x_shape == ((1,) + self_shape):
result = np.empty(ishape, dtype=np.object)
for i in xrange(np.prod(ishape)):
ii = np.unravel_index(i, ishape)
result = np.empty(shape, dtype=np.object)
for i in xrange(np.prod(shape)):
ii = np.unravel_index(i, shape)
result[ii] = x[0]
# Case 4: fallback: try to cast x with self.domain
else: # Iterate over the outermost dimension and cast the inner spaces
result = np.empty(ishape, dtype=np.object)
for i in xrange(np.prod(ishape)):
ii = np.unravel_index(i, ishape)
result = np.empty(shape, dtype=np.object)
for i in xrange(np.prod(shape)):
ii = np.unravel_index(i, shape)
result[ii] = x
return result
......@@ -2903,4 +3012,4 @@ class field(object):
class EmptyField(field):
def __init__(self):
pass
\ No newline at end of file
pass
......@@ -56,11 +56,11 @@ all_hp_datatypes = [np.dtype('float64')]
###############################################################################
DATAMODELS = {}
DATAMODELS['point_space'] = ['np'] + POINT_DISTRIBUTION_STRATEGIES
DATAMODELS['rg_space'] = ['np'] + RG_DISTRIBUTION_STRATEGIES
DATAMODELS['lm_space'] = ['np'] + LM_DISTRIBUTION_STRATEGIES
DATAMODELS['gl_space'] = ['np'] + GL_DISTRIBUTION_STRATEGIES
DATAMODELS['hp_space'] = ['np'] + HP_DISTRIBUTION_STRATEGIES
DATAMODELS['point_space'] = POINT_DISTRIBUTION_STRATEGIES
DATAMODELS['rg_space'] = RG_DISTRIBUTION_STRATEGIES
DATAMODELS['lm_space'] = LM_DISTRIBUTION_STRATEGIES
DATAMODELS['gl_space'] = GL_DISTRIBUTION_STRATEGIES
DATAMODELS['hp_space'] = HP_DISTRIBUTION_STRATEGIES
###############################################################################
......@@ -110,10 +110,9 @@ class Test_field_init(unittest.TestCase):
@parameterized.expand(space_list)
def test_successfull_init_and_attributes(self, s):
s = s[0]
f = field(s)
assert(f.domain is s)
assert(s.check_codomain(f.codomain))
f = field(domain=np.array([s]), dtype=s.dtype)
assert(f.domain[0] is s)
assert(s.check_codomain(f.codomain[0]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment