Commit d8b9b918 authored by csongor's avatar csongor
Browse files

WIP: Field support for multiple spaces.

parent 25f052c3
......@@ -1994,7 +1994,11 @@ class field(object):
return gotten
def get_shape(self):
return self.domain.get_shape()
global_shape = np.sum([space.get_shape() for space in self.domain])
if isinstance(global_shape, tuple):
return global_shape
else:
return ()
def get_dim(self, split=False):
"""
......@@ -2013,37 +2017,23 @@ class field(object):
Dimension of space.
"""
return self.domain.get_dim(split=split)
return np.prod(np.sum([space.get_shape() for space in self.domain]))
def get_dof(self, split=False):
return self.domain.get_dof(split=split)
def get_ishape(self):
return self.ishape
def get_global_shape(self):
global_shape = np.sum([space.get_shape() for space in self.domain])
if isinstance(global_shape, tuple):
return global_shape
else:
return ()
return np.sum([len(space.get_shape()) for space in self.domain])
def _map(self, function, *args):
return utilities.field_map(self.get_global_shape(), function, *args)
return utilities.field_map(self.get_shape(), function, *args)
def cast(self, x=None, dtype=None):
if dtype is not None:
dtype = np.dtype(dtype)
if dtype is None:
dtype = self.dtype
casted_x = self._cast_to_shape(x)
if self.get_global_shape() == ():
return self.domain.cast(casted_x)
else:
return self._map(lambda z: self.domain.cast(z),
casted_x)
casted_x = self._cast_to_d2o(x, dtype=dtype)
return self._complement_cast(casted_x)
def _cast_to_d2o(self, x, dtype=None, **kwargs):
def _cast_to_d2o(self, x, dtype=None, shape=None, **kwargs):
"""
Computes valid field values from a given object, trying
to translate the given data into a valid form. Thereby it is as
......@@ -2066,29 +2056,34 @@ class field(object):
Whether the method should raise a warning if information is
lost during casting (default: False).
"""
if isinstance(x, field):
x = x.get_val()
if dtype is None:
dtype = self.dtype
if shape is None:
shape = self.get_shape()
# Case 1: x is a distributed_data_object
if isinstance(x, distributed_data_object):
to_copy = False
# Check the shape
if np.any(np.array(x.shape) != np.array(self.get_shape())):
if np.any(np.array(x.shape) != np.array(shape)):
# Check if at least the number of degrees of freedom is equal
if x.get_dim() == self.get_dim():
try:
temp = x.copy_empty(global_shape=self.get_shape())
temp.set_local_data(x.get_local_data(), copy=False)
temp = x.copy_empty(global_shape=shape)
temp.set_local_data(x, copy=False)
except:
# If the number of dof is equal or 1, use np.reshape...
about.warnings.cflush(
"WARNING: Trying to reshape the data. This " +
"operation is expensive as it consolidates the " +
"full data!\n")
temp = x.get_full_data()
temp = np.reshape(temp, self.get_shape())
temp = x
temp = np.reshape(temp, shape)
# ... and cast again
return self._cast_to_d2o(temp,
dtype=dtype,
......@@ -2133,108 +2128,10 @@ class field(object):
# Cast the d2o
return self.cast(x, dtype=dtype)
def _cast_to_shape(self, x):
if isinstance(x, field):
x = x.get_val()
global_shape = self.get_global_shape()
if global_shape == ():
casted_x = self._cast_to_scalar_helper(x)
else:
casted_x = self._cast_to_tensor_helper(x, shape=global_shape)
return casted_x
def _cast_to_scalar_helper(self, x):
# if x is already a scalar or does fit directly, return it
self_shape = self.domain.get_shape()
x_shape = np.shape(x)
if np.isscalar(x) or x_shape == self_shape:
return x
# check if the given object is a 'container'
try:
container_Q = (x.dtype.type == np.object_)
except(AttributeError):
container_Q = False
if container_Q:
# extract the first element. This works on 0-d ndarrays, too.
result = x[(0,) * len(x_shape)]
return result
# if x is no container-type, it could be that the needed shape
# for self.domain is encapsulated in x
if x_shape[len(x_shape) - len(self_shape):] == self_shape:
if x_shape[:len(x_shape) - len(self_shape)] != (1,):
about.warnings.cprint(
"WARNING: discarding all internal dimensions " +
"except for the first one.")
result = x
for i in xrange(len(x_shape) - len(self_shape)):
result = result[0]
return result
# In all other cases, cast x directly
def _complement_cast(self, x):
#TODO implement complement cast for multiple spaces.
return x
def _cast_to_tensor_helper(self, x, shape=None):
if shape is None:
shape = self.get_global_shape()
# Check if x is a container of proper length
# containing something which will then checked by the domain-space
x_shape = np.shape(x)
self_shape = self.get_global_shape()
try:
container_Q = (x.dtype.type == np.object_)
except(AttributeError):
container_Q = False
if container_Q:
if x_shape == shape:
return x
elif x_shape == shape[:len(x_shape)]:
return x.reshape(x_shape +
(1,) * (len(shape) - len(x_shape)))
# Slow track: x could be a pure ndarray
# Case 1 and 2:
# 1: There are cases where np.shape will only find the container
# although it was no np.object array; e.g. for [a,1].
# 2: The overall shape is already the right one
if x_shape == shape or x_shape == (shape + self_shape):
# Iterate over the outermost dimension and cast the inner spaces
result = np.empty(shape, dtype=np.object)
for i in xrange(np.prod(shape)):
ii = np.unravel_index(i, shape)
try:
result[ii] = x[ii]
except(TypeError):
extracted = x
for j in xrange(len(ii)):
extracted = extracted[ii[j]]
result[ii] = extracted
# Case 3: The overall shape does not match directly.
# Check if the input has shape (1, self.domain.shape)
# Iterate over the outermost dimension and cast the inner spaces
elif x_shape == ((1,) + self_shape):
result = np.empty(shape, dtype=np.object)
for i in xrange(np.prod(shape)):
ii = np.unravel_index(i, shape)
result[ii] = x[0]
# Case 4: fallback: try to cast x with self.domain
else: # Iterate over the outermost dimension and cast the inner spaces
result = np.empty(shape, dtype=np.object)
for i in xrange(np.prod(shape)):
ii = np.unravel_index(i, shape)
result[ii] = x
return result
def set_domain(self, new_domain=None, force=False):
"""
Resets the codomain of the field.
......
......@@ -110,7 +110,7 @@ class Test_field_init(unittest.TestCase):
@parameterized.expand(space_list)
def test_successfull_init_and_attributes(self, s):
f = field(domain=np.array([s]), dtype=s.dtype)
f = field(domain=np.array([s]), dtype=s.dtype, datamodel=s.datamodel)
assert(f.domain[0] is s)
assert(s.check_codomain(f.codomain[0]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment