Commit 01f48ddd authored by csongor's avatar csongor
Browse files

Fix _axis_list functionality

parent ee7a739e
......@@ -103,7 +103,7 @@ class field(object):
"""
def __init__(self, domain=None, val=None, codomain=None,
comm=gc['default_comm'], copy=False, dtype=np.dtype('float64'),
comm=gc['default_comm'], copy=False, dtype=None,
datamodel='fftw', **kwargs):
"""
Sets the attributes for a field class instance.
......@@ -214,21 +214,24 @@ class field(object):
if kwargs == {}:
val = self.cast(0)
else:
val = map(lambda z: self.get_random_values(domain = self.domain,
codomain=z, **kwargs), self.codomain)
val = map(lambda z: self.get_random_values(domain=self.domain,
codomain=z,
**kwargs),
self.codomain)
self.set_val(new_val=val, copy=copy)
def _get_dtype_from_domain(self, domain=None):
if domain is None:
domain = self.domain
dtype_tuple = tuple(space.dtype for space in domain)
dtype = np.result_type(dtype_tuple)
dtype_tuple = tuple(np.dtype(space.dtype) for space in domain)
dtype = reduce(lambda x,y: np.result_type(x,y), dtype_tuple)
return dtype
def _get_axis_list_from_domain(self, domain=None):
if domain is None:
domain = self.domain
axis_list = [space.get_shape() for space in domain]
axis_list = [tuple(ind for i in range(len(space.get_shape()))) for
  • Indices should be incremented:

    al = ((1,2,3),(4,5),(6,7),(8,))
    sub_space.calc_transform(data, axes=al[space_index])
    Edited by Theo Steininger
Please register or sign in to reply
ind, space in enumerate(domain)]
return axis_list
def _parse_comm(self, comm):
......@@ -281,7 +284,6 @@ class field(object):
self.codomain = codomain
return codomain
def get_random_values(self, **kwargs):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'enforce_power'."))
......@@ -380,8 +382,8 @@ class field(object):
def get_shape(self):
if len(self.domain) > 1:
global_shape = reduce(lambda x, y: x.get_shape() + y.get_shape(),
self.domain)
shape_tuple = tuple(space.get_shape() for space in self.domain)
global_shape = reduce(lambda x,y: x+y, shape_tuple)
else:
global_shape = self.domain[0].get_shape()
......
......@@ -109,7 +109,7 @@ def generate_space_with_size(name, num):
'rg_space': rg_space((num, num)),
'lm_space': lm_space(mmax=num+1, lmax=num+1),
'hp_space': hp_space(num),
'gl_space': gl_space(nlat=num, nlon=num),
'gl_space': gl_space(nlat=num, nlon=2*num-1),
}
return space_dict[name]
......@@ -156,7 +156,7 @@ class Test_field_init2(unittest.TestCase):
assert (s.check_codomain(f.codomain[0]))
assert (s.get_shape() == f.get_shape())
class Test_field_multiple_init(unittest.TestCase):
class Test_field_multiple_rg_init(unittest.TestCase):
@parameterized.expand(
itertools.product([(1,)],
[True],
......@@ -182,6 +182,30 @@ class Test_field_multiple_init(unittest.TestCase):
assert (s2.check_codomain(f.codomain[1]))
assert (s1.get_shape() + s2.get_shape() == f.get_shape())
class Test_field_multiple_init(unittest.TestCase):
@parameterized.expand(
itertools.product(point_like_spaces, point_like_spaces, [4]),
testcase_func_name=custom_name_func)
def test_multiple_space_init(self, space1, space2, shape):
s1 = generate_space_with_size(space1, shape)
s2 = generate_space_with_size(space2, shape)
f = field(domain=(s1, s2))
assert (f.domain[0] is s1)
assert (f.domain[1] is s2)
assert (s1.check_codomain(f.codomain[0]))
assert (s2.check_codomain(f.codomain[1]))
assert (s1.get_shape() + s2.get_shape() == f.get_shape())
s3 = generate_space_with_size('hp_space',shape)
f = field(domain=(s1, s2, s3))
assert (f.domain[0] is s1)
assert (f.domain[1] is s2)
assert (f.domain[2] is s3)
assert (s1.check_codomain(f.codomain[0]))
assert (s2.check_codomain(f.codomain[1]))
assert (s3.check_codomain(f.codomain[2]))
assert (s1.get_shape() + s2.get_shape() + s3.get_shape() ==
f.get_shape())
class Test_axis(unittest.TestCase):
@parameterized.expand(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment