diff --git a/nifty_field.py b/nifty_field.py index 1541cea58ead7084de5efdd9e98cd0a47e15b559..bbd988f3c73ec571ef6839b15ee0fbc145fc3e3c 100644 --- a/nifty_field.py +++ b/nifty_field.py @@ -103,7 +103,7 @@ class field(object): """ def __init__(self, domain=None, val=None, codomain=None, - comm=gc['default_comm'], copy=False, dtype=np.dtype('float64'), + comm=gc['default_comm'], copy=False, dtype=None, datamodel='fftw', **kwargs): """ Sets the attributes for a field class instance. @@ -214,21 +214,24 @@ class field(object): if kwargs == {}: val = self.cast(0) else: - val = map(lambda z: self.get_random_values(domain = self.domain, - codomain=z, **kwargs), self.codomain) + val = map(lambda z: self.get_random_values(domain=self.domain, + codomain=z, + **kwargs), + self.codomain) self.set_val(new_val=val, copy=copy) def _get_dtype_from_domain(self, domain=None): if domain is None: domain = self.domain - dtype_tuple = tuple(space.dtype for space in domain) - dtype = np.result_type(dtype_tuple) + dtype_tuple = tuple(np.dtype(space.dtype) for space in domain) + dtype = reduce(lambda x,y: np.result_type(x,y), dtype_tuple) return dtype def _get_axis_list_from_domain(self, domain=None): if domain is None: domain = self.domain - axis_list = [space.get_shape() for space in domain] + axis_list = [tuple(ind for i in range(len(space.get_shape()))) for + ind, space in enumerate(domain)] return axis_list def _parse_comm(self, comm): @@ -281,7 +284,6 @@ class field(object): self.codomain = codomain return codomain - def get_random_values(self, **kwargs): raise NotImplementedError(about._errors.cstring( "ERROR: no generic instance method 'enforce_power'.")) @@ -380,8 +382,8 @@ class field(object): def get_shape(self): if len(self.domain) > 1: - global_shape = reduce(lambda x, y: x.get_shape() + y.get_shape(), - self.domain) + shape_tuple = tuple(space.get_shape() for space in self.domain) + global_shape = reduce(lambda x,y: x+y, shape_tuple) else: global_shape = self.domain[0].get_shape() @@ -1076,7 +1078,7 @@ class field(object): """ return self._unary_operation(self.get_val(), op='median', - **kwargs) + **kwargs) def mean(self, **kwargs): """ @@ -1093,7 +1095,7 @@ class field(object): """ return self._unary_operation(self.get_val(), op='mean', - **kwargs) + **kwargs) def std(self, **kwargs): """ @@ -1110,7 +1112,7 @@ class field(object): """ return self._unary_operation(self.get_val(), op='std', - **kwargs) + **kwargs) def var(self, **kwargs): """ @@ -1127,7 +1129,7 @@ class field(object): """ return self._unary_operation(self.get_val(), op='var', - **kwargs) + **kwargs) def argmin(self, split=False, **kwargs): """ @@ -1153,10 +1155,10 @@ class field(object): """ if split: return self._unary_operation(self.get_val(), op='argmin_nonflat', - **kwargs) + **kwargs) else: return self._unary_operation(self.get_val(), op='argmin', - **kwargs) + **kwargs) def argmax(self, split=False, **kwargs): """ @@ -1182,10 +1184,10 @@ class field(object): """ if split: return self._unary_operation(self.get_val(), op='argmax_nonflat', - **kwargs) + **kwargs) else: return self._unary_operation(self.get_val(), op='argmax', - **kwargs) + **kwargs) # TODO: Implement the full range of unary and binary operotions diff --git a/test/test_nifty_field.py b/test/test_nifty_field.py index 04790bc155ec6020ac3e46b5c285570790d30749..5a18f389d80a7d024f33ad3ed22483ca4288397c 100644 --- a/test/test_nifty_field.py +++ b/test/test_nifty_field.py @@ -109,7 +109,7 @@ def generate_space_with_size(name, num): 'rg_space': rg_space((num, num)), 'lm_space': lm_space(mmax=num+1, lmax=num+1), 'hp_space': hp_space(num), - 'gl_space': gl_space(nlat=num, nlon=num), + 'gl_space': gl_space(nlat=num, nlon=2*num-1), } return space_dict[name] @@ -156,7 +156,7 @@ class Test_field_init2(unittest.TestCase): assert (s.check_codomain(f.codomain[0])) assert (s.get_shape() == f.get_shape()) -class Test_field_multiple_init(unittest.TestCase): +class Test_field_multiple_rg_init(unittest.TestCase): @parameterized.expand( itertools.product([(1,)], [True], @@ -182,6 +182,30 @@ class Test_field_multiple_init(unittest.TestCase): assert (s2.check_codomain(f.codomain[1])) assert (s1.get_shape() + s2.get_shape() == f.get_shape()) +class Test_field_multiple_init(unittest.TestCase): + @parameterized.expand( + itertools.product(point_like_spaces, point_like_spaces, [4]), + testcase_func_name=custom_name_func) + def test_multiple_space_init(self, space1, space2, shape): + s1 = generate_space_with_size(space1, shape) + s2 = generate_space_with_size(space2, shape) + f = field(domain=(s1, s2)) + assert (f.domain[0] is s1) + assert (f.domain[1] is s2) + assert (s1.check_codomain(f.codomain[0])) + assert (s2.check_codomain(f.codomain[1])) + assert (s1.get_shape() + s2.get_shape() == f.get_shape()) + s3 = generate_space_with_size('hp_space',shape) + f = field(domain=(s1, s2, s3)) + assert (f.domain[0] is s1) + assert (f.domain[1] is s2) + assert (f.domain[2] is s3) + assert (s1.check_codomain(f.codomain[0])) + assert (s2.check_codomain(f.codomain[1])) + assert (s3.check_codomain(f.codomain[2])) + assert (s1.get_shape() + s2.get_shape() + s3.get_shape() == + f.get_shape()) + class Test_axis(unittest.TestCase): @parameterized.expand(