Commit 89729ee5 authored by csongor's avatar csongor
Browse files

Fix _axis_list functionality

parent 01f48ddd
......@@ -224,14 +224,20 @@ class field(object):
if domain is None:
domain = self.domain
dtype_tuple = tuple(np.dtype(space.dtype) for space in domain)
dtype = reduce(lambda x,y: np.result_type(x,y), dtype_tuple)
dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple)
return dtype
def _get_axis_list_from_domain(self, domain=None):
if domain is None:
domain = self.domain
axis_list = [tuple(ind for i in range(len(space.get_shape()))) for
ind, space in enumerate(domain)]
i = 0
axis_list = []
for space in domain:
l = []
for j in range(len(space.get_shape())):
l += [i]
i += 1
axis_list += [tuple(l)]
return axis_list
def _parse_comm(self, comm):
......@@ -383,7 +389,7 @@ class field(object):
def get_shape(self):
if len(self.domain) > 1:
shape_tuple = tuple(space.get_shape() for space in self.domain)
global_shape = reduce(lambda x,y: x+y, shape_tuple)
global_shape = reduce(lambda x, y: x + y, shape_tuple)
else:
global_shape = self.domain[0].get_shape()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment