Fix _axis_list functionality
... | ... | @@ -103,7 +103,7 @@ class field(object): |
""" | ||
def __init__(self, domain=None, val=None, codomain=None, | ||
comm=gc['default_comm'], copy=False, dtype=np.dtype('float64'), | ||
comm=gc['default_comm'], copy=False, dtype=None, | ||
datamodel='fftw', **kwargs): | ||
""" | ||
Sets the attributes for a field class instance. | ||
... | ... | @@ -214,21 +214,24 @@ class field(object): |
if kwargs == {}: | ||
val = self.cast(0) | ||
else: | ||
val = map(lambda z: self.get_random_values(domain = self.domain, | ||
codomain=z, **kwargs), self.codomain) | ||
val = map(lambda z: self.get_random_values(domain=self.domain, | ||
codomain=z, | ||
**kwargs), | ||
self.codomain) | ||
self.set_val(new_val=val, copy=copy) | ||
def _get_dtype_from_domain(self, domain=None): | ||
if domain is None: | ||
domain = self.domain | ||
dtype_tuple = tuple(space.dtype for space in domain) | ||
dtype = np.result_type(dtype_tuple) | ||
dtype_tuple = tuple(np.dtype(space.dtype) for space in domain) | ||
dtype = reduce(lambda x,y: np.result_type(x,y), dtype_tuple) | ||
return dtype | ||
def _get_axis_list_from_domain(self, domain=None): | ||
if domain is None: | ||
domain = self.domain | ||
axis_list = [space.get_shape() for space in domain] | ||
axis_list = [tuple(ind for i in range(len(space.get_shape()))) for | ||
|
||
ind, space in enumerate(domain)] | ||
return axis_list | ||
def _parse_comm(self, comm): | ||
... | ... | @@ -281,7 +284,6 @@ class field(object): |
self.codomain = codomain | ||
return codomain | ||
def get_random_values(self, **kwargs): | ||
raise NotImplementedError(about._errors.cstring( | ||
"ERROR: no generic instance method 'enforce_power'.")) | ||
... | ... | @@ -380,8 +382,8 @@ class field(object): |
def get_shape(self): | ||
if len(self.domain) > 1: | ||
global_shape = reduce(lambda x, y: x.get_shape() + y.get_shape(), | ||
self.domain) | ||
shape_tuple = tuple(space.get_shape() for space in self.domain) | ||
global_shape = reduce(lambda x,y: x+y, shape_tuple) | ||
else: | ||
global_shape = self.domain[0].get_shape() | ||
... | ... | @@ -1076,7 +1078,7 @@ class field(object): |
""" | ||
return self._unary_operation(self.get_val(), op='median', | ||
**kwargs) | ||
**kwargs) | ||
def mean(self, **kwargs): | ||
""" | ||
... | ... | @@ -1093,7 +1095,7 @@ class field(object): |
""" | ||
return self._unary_operation(self.get_val(), op='mean', | ||
**kwargs) | ||
**kwargs) | ||
def std(self, **kwargs): | ||
""" | ||
... | ... | @@ -1110,7 +1112,7 @@ class field(object): |
""" | ||
return self._unary_operation(self.get_val(), op='std', | ||
**kwargs) | ||
**kwargs) | ||
def var(self, **kwargs): | ||
""" | ||
... | ... | @@ -1127,7 +1129,7 @@ class field(object): |
""" | ||
return self._unary_operation(self.get_val(), op='var', | ||
**kwargs) | ||
**kwargs) | ||
def argmin(self, split=False, **kwargs): | ||
""" | ||
... | ... | @@ -1153,10 +1155,10 @@ class field(object): |
""" | ||
if split: | ||
return self._unary_operation(self.get_val(), op='argmin_nonflat', | ||
**kwargs) | ||
**kwargs) | ||
else: | ||
return self._unary_operation(self.get_val(), op='argmin', | ||
**kwargs) | ||
**kwargs) | ||
def argmax(self, split=False, **kwargs): | ||
""" | ||
... | ... | @@ -1182,10 +1184,10 @@ class field(object): |
""" | ||
if split: | ||
return self._unary_operation(self.get_val(), op='argmax_nonflat', | ||
**kwargs) | ||
**kwargs) | ||
else: | ||
return self._unary_operation(self.get_val(), op='argmax', | ||
**kwargs) | ||
**kwargs) | ||
# TODO: Implement the full range of unary and binary operotions | ||
... | ... |