Commit 01f48ddd authored by csongor's avatar csongor
Browse files

Fix _axis_list functionality

parent ee7a739e
...@@ -103,7 +103,7 @@ class field(object): ...@@ -103,7 +103,7 @@ class field(object):
""" """
def __init__(self, domain=None, val=None, codomain=None, def __init__(self, domain=None, val=None, codomain=None,
comm=gc['default_comm'], copy=False, dtype=np.dtype('float64'), comm=gc['default_comm'], copy=False, dtype=None,
datamodel='fftw', **kwargs): datamodel='fftw', **kwargs):
""" """
Sets the attributes for a field class instance. Sets the attributes for a field class instance.
...@@ -214,21 +214,24 @@ class field(object): ...@@ -214,21 +214,24 @@ class field(object):
if kwargs == {}: if kwargs == {}:
val = self.cast(0) val = self.cast(0)
else: else:
val = map(lambda z: self.get_random_values(domain = self.domain, val = map(lambda z: self.get_random_values(domain=self.domain,
codomain=z, **kwargs), self.codomain) codomain=z,
**kwargs),
self.codomain)
self.set_val(new_val=val, copy=copy) self.set_val(new_val=val, copy=copy)
def _get_dtype_from_domain(self, domain=None): def _get_dtype_from_domain(self, domain=None):
if domain is None: if domain is None:
domain = self.domain domain = self.domain
dtype_tuple = tuple(space.dtype for space in domain) dtype_tuple = tuple(np.dtype(space.dtype) for space in domain)
dtype = np.result_type(dtype_tuple) dtype = reduce(lambda x,y: np.result_type(x,y), dtype_tuple)
return dtype return dtype
def _get_axis_list_from_domain(self, domain=None): def _get_axis_list_from_domain(self, domain=None):
if domain is None: if domain is None:
domain = self.domain domain = self.domain
axis_list = [space.get_shape() for space in domain] axis_list = [tuple(ind for i in range(len(space.get_shape()))) for
  • Indices should be incremented:

    al = ((1,2,3),(4,5),(6,7),(8,))
    sub_space.calc_transform(data, axes=al[space_index])
    Edited by Theo Steininger
Please register or sign in to reply
ind, space in enumerate(domain)]
return axis_list return axis_list
def _parse_comm(self, comm): def _parse_comm(self, comm):
...@@ -281,7 +284,6 @@ class field(object): ...@@ -281,7 +284,6 @@ class field(object):
self.codomain = codomain self.codomain = codomain
return codomain return codomain
def get_random_values(self, **kwargs): def get_random_values(self, **kwargs):
raise NotImplementedError(about._errors.cstring( raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'enforce_power'.")) "ERROR: no generic instance method 'enforce_power'."))
...@@ -380,8 +382,8 @@ class field(object): ...@@ -380,8 +382,8 @@ class field(object):
def get_shape(self): def get_shape(self):
if len(self.domain) > 1: if len(self.domain) > 1:
global_shape = reduce(lambda x, y: x.get_shape() + y.get_shape(), shape_tuple = tuple(space.get_shape() for space in self.domain)
self.domain) global_shape = reduce(lambda x,y: x+y, shape_tuple)
else: else:
global_shape = self.domain[0].get_shape() global_shape = self.domain[0].get_shape()
...@@ -1076,7 +1078,7 @@ class field(object): ...@@ -1076,7 +1078,7 @@ class field(object):
""" """
return self._unary_operation(self.get_val(), op='median', return self._unary_operation(self.get_val(), op='median',
**kwargs) **kwargs)
def mean(self, **kwargs): def mean(self, **kwargs):
""" """
...@@ -1093,7 +1095,7 @@ class field(object): ...@@ -1093,7 +1095,7 @@ class field(object):
""" """
return self._unary_operation(self.get_val(), op='mean', return self._unary_operation(self.get_val(), op='mean',
**kwargs) **kwargs)
def std(self, **kwargs): def std(self, **kwargs):
""" """
...@@ -1110,7 +1112,7 @@ class field(object): ...@@ -1110,7 +1112,7 @@ class field(object):
""" """
return self._unary_operation(self.get_val(), op='std', return self._unary_operation(self.get_val(), op='std',
**kwargs) **kwargs)
def var(self, **kwargs): def var(self, **kwargs):
""" """
...@@ -1127,7 +1129,7 @@ class field(object): ...@@ -1127,7 +1129,7 @@ class field(object):
""" """
return self._unary_operation(self.get_val(), op='var', return self._unary_operation(self.get_val(), op='var',
**kwargs) **kwargs)
def argmin(self, split=False, **kwargs): def argmin(self, split=False, **kwargs):
""" """
...@@ -1153,10 +1155,10 @@ class field(object): ...@@ -1153,10 +1155,10 @@ class field(object):
""" """
if split: if split:
return self._unary_operation(self.get_val(), op='argmin_nonflat', return self._unary_operation(self.get_val(), op='argmin_nonflat',
**kwargs) **kwargs)
else: else:
return self._unary_operation(self.get_val(), op='argmin', return self._unary_operation(self.get_val(), op='argmin',
**kwargs) **kwargs)
def argmax(self, split=False, **kwargs): def argmax(self, split=False, **kwargs):
""" """
...@@ -1182,10 +1184,10 @@ class field(object): ...@@ -1182,10 +1184,10 @@ class field(object):
""" """
if split: if split:
return self._unary_operation(self.get_val(), op='argmax_nonflat', return self._unary_operation(self.get_val(), op='argmax_nonflat',
**kwargs) **kwargs)
else: else:
return self._unary_operation(self.get_val(), op='argmax', return self._unary_operation(self.get_val(), op='argmax',
**kwargs) **kwargs)
# TODO: Implement the full range of unary and binary operotions # TODO: Implement the full range of unary and binary operotions
......
...@@ -109,7 +109,7 @@ def generate_space_with_size(name, num): ...@@ -109,7 +109,7 @@ def generate_space_with_size(name, num):
'rg_space': rg_space((num, num)), 'rg_space': rg_space((num, num)),
'lm_space': lm_space(mmax=num+1, lmax=num+1), 'lm_space': lm_space(mmax=num+1, lmax=num+1),
'hp_space': hp_space(num), 'hp_space': hp_space(num),
'gl_space': gl_space(nlat=num, nlon=num), 'gl_space': gl_space(nlat=num, nlon=2*num-1),
} }
return space_dict[name] return space_dict[name]
...@@ -156,7 +156,7 @@ class Test_field_init2(unittest.TestCase): ...@@ -156,7 +156,7 @@ class Test_field_init2(unittest.TestCase):
assert (s.check_codomain(f.codomain[0])) assert (s.check_codomain(f.codomain[0]))
assert (s.get_shape() == f.get_shape()) assert (s.get_shape() == f.get_shape())
class Test_field_multiple_init(unittest.TestCase): class Test_field_multiple_rg_init(unittest.TestCase):
@parameterized.expand( @parameterized.expand(
itertools.product([(1,)], itertools.product([(1,)],
[True], [True],
...@@ -182,6 +182,30 @@ class Test_field_multiple_init(unittest.TestCase): ...@@ -182,6 +182,30 @@ class Test_field_multiple_init(unittest.TestCase):
assert (s2.check_codomain(f.codomain[1])) assert (s2.check_codomain(f.codomain[1]))
assert (s1.get_shape() + s2.get_shape() == f.get_shape()) assert (s1.get_shape() + s2.get_shape() == f.get_shape())
class Test_field_multiple_init(unittest.TestCase):
@parameterized.expand(
itertools.product(point_like_spaces, point_like_spaces, [4]),
testcase_func_name=custom_name_func)
def test_multiple_space_init(self, space1, space2, shape):
s1 = generate_space_with_size(space1, shape)
s2 = generate_space_with_size(space2, shape)
f = field(domain=(s1, s2))
assert (f.domain[0] is s1)
assert (f.domain[1] is s2)
assert (s1.check_codomain(f.codomain[0]))
assert (s2.check_codomain(f.codomain[1]))
assert (s1.get_shape() + s2.get_shape() == f.get_shape())
s3 = generate_space_with_size('hp_space',shape)
f = field(domain=(s1, s2, s3))
assert (f.domain[0] is s1)
assert (f.domain[1] is s2)
assert (f.domain[2] is s3)
assert (s1.check_codomain(f.codomain[0]))
assert (s2.check_codomain(f.codomain[1]))
assert (s3.check_codomain(f.codomain[2]))
assert (s1.get_shape() + s2.get_shape() + s3.get_shape() ==
f.get_shape())
class Test_axis(unittest.TestCase): class Test_axis(unittest.TestCase):
@parameterized.expand( @parameterized.expand(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment