diff --git a/nifty_core.py b/nifty_core.py index 4b7b9bfdc394c5a455b71d3bc7446d418417043d..350aa4d05c5c0d807ab91c9aa69e37d3b872ee2b 100644 --- a/nifty_core.py +++ b/nifty_core.py @@ -911,9 +911,11 @@ class point_space(space): 'mean': lambda y: getattr(y, 'mean')(axis=axis), 'std': lambda y: getattr(y, 'std')(axis=axis), 'var': lambda y: getattr(y, 'var')(axis=axis), - 'argmin_nonflat': lambda y: getattr(y, 'argmin_nonflat')(axis=axis), + 'argmin_nonflat': lambda y: getattr(y, 'argmin_nonflat')( + axis=axis), 'argmin': lambda y: getattr(y, 'argmin')(axis=axis), - 'argmax_nonflat': lambda y: getattr(y, 'argmax_nonflat')(axis=axis), + 'argmax_nonflat': lambda y: getattr(y, 'argmax_nonflat')( + axis=axis), 'argmax': lambda y: getattr(y, 'argmax')(axis=axis), 'conjugate': lambda y: getattr(y, 'conjugate')(), 'sum': lambda y: getattr(y, 'sum')(axis=axis), @@ -1038,25 +1040,7 @@ class point_space(space): return self.calc_weight(mol, power=1) def cast(self, x=None, dtype=None, **kwargs): - if dtype is not None: - dtype = np.dtype(dtype) - - # If x is a field, extract the data and do a recursive call - if isinstance(x, field): - # Check if the domain matches - if self != x.domain: - about.warnings.cflush( - "WARNING: Getting data from foreign domain!") - # Extract the data, whatever it is, and cast it again - return self.cast(x.val, - dtype=dtype, - **kwargs) - - else: - return self._cast_to_d2o(x=x, - dtype=dtype, - **kwargs) - + return self._cast_to_d2o(x=x, dtype=dtype, **kwargs) def _cast_to_d2o(self, x, dtype=None, **kwargs): """ @@ -1081,6 +1065,8 @@ class point_space(space): Whether the method should raise a warning if information is lost during casting (default: False). """ + if dtype is not None: + dtype = np.dtype(dtype) if dtype is None: dtype = self.dtype @@ -1357,8 +1343,8 @@ class point_space(space): processed_std = std else: try: - processed_std = sample.distributor.\ - extract_local_data(std) + processed_std = sample.distributor. \ + extract_local_data(std) except(AttributeError): processed_std = std @@ -1375,8 +1361,6 @@ class point_space(space): vmax=arg['vmax'])) return sample - - def calc_weight(self, x, power=1): """ Weights a given array of field values with the pixel volumes (not @@ -1575,7 +1559,7 @@ class point_space(space): ax0 = fig.add_axes([0.12, 0.12, 0.82, 0.76]) xaxes = np.arange(self.para[0], dtype=np.dtype('int')) - if(norm == "log")and(vmin <= 0): + if (norm == "log") and (vmin <= 0): raise ValueError(about._errors.cstring( "ERROR: nonpositive value(s).")) @@ -1741,8 +1725,9 @@ class field(object): """ - def __init__(self, domain=None, val=None, codomain=None, ishape=None, - copy=False, **kwargs): + def __init__(self, domain=None, val=None, codomain=None, + copy=False, dtype=np.dtype('float64'), datamodel='not', + **kwargs): """ Sets the attributes for a field class instance. @@ -1771,32 +1756,31 @@ class field(object): self._init_from_field(f=val, domain=domain, codomain=codomain, - ishape=ishape, copy=copy, + dtype=dtype, + datamodel=datamodel, **kwargs) else: self._init_from_array(val=val, domain=domain, codomain=codomain, - ishape=ishape, copy=copy, + dtype=dtype, + datamodel=datamodel, **kwargs) - def _init_from_field(self, f, domain, codomain, ishape, copy, **kwargs): + def _init_from_field(self, f, domain, codomain, copy, dtype, datamodel, + **kwargs): # check domain if domain is None: domain = f.domain # check codomain if codomain is None: - if domain.check_codomain(f.codomain): + if self.check_codomain(domain, f.codomain): codomain = f.codomain else: - codomain = domain.get_codomain() - - # check for ishape - if ishape is None: - ishape = f.ishape + codomain = self.get_codomain(domain) # Check if the given field lives in a space which is compatible to the # given domain @@ -1808,51 +1792,78 @@ class field(object): self._init_from_array(domain=domain, val=f.val, codomain=codomain, - ishape=ishape, copy=copy, + dtype=dtype, + datamodel=datamodel, **kwargs) - def _init_from_array(self, val, domain, codomain, ishape, copy, **kwargs): + def _init_from_array(self, val, domain, codomain, copy, dtype, datamodel, + **kwargs): + if dtype is None: + dtype = np.dtype('float64') + self.dtype = dtype + + if datamodel not in DISTRIBUTION_STRATEGIES['global']: + about.warnings.cprint("WARNING: datamodel set to default.") + self.datamodel = \ + gc['default_distribution_strategy'] + else: + self.datamodel = datamodel # check domain - if not isinstance(domain, space): - raise TypeError(about._errors.cstring( - "ERROR: Given domain is not a space.")) - self.domain = domain + self.domain = self.check_valid_domain(domain=domain) # check codomain if codomain is None: - codomain = domain.get_codomain() - elif not self.domain.check_codomain(codomain): + codomain = self.get_codomain(domain) + elif not self.check_codomain(domain=domain, codomain=codomain): raise ValueError(about._errors.cstring( "ERROR: The given codomain is not compatible to the domain.")) self.codomain = codomain - if ishape is not None: - ishape = tuple(np.array(ishape, dtype=np.uint).flatten()) - elif val is not None: - try: - if val.dtype.type == np.object_: - ishape = val.shape - else: - ishape = () - except(AttributeError): - try: - ishape = val.ishape - except(AttributeError): - ishape = () - else: - ishape = () - self.ishape = ishape - if val is None: if kwargs == {}: - val = self._map(lambda: self.domain.cast(0.)) + val = self._map(lambda: self.cast((0,))) else: val = self._map(lambda: self.domain.get_random_values( codomain=self.codomain, **kwargs)) self.set_val(new_val=val, copy=copy) + def check_valid_domain(self, domain): + if not isinstance(domain, np.ndarray): + raise TypeError(about._errors.cstring( + "ERROR: The given domain is not a list.")) + for d in domain: + if not isinstance(d, space): + raise TypeError(about._errors.cstring( + "ERROR: Given domain is not a space.")) + elif d.dtype != self.dtype: + raise AttributeError(about._errors.cstring( + "ERROR: The dtype of a space in the domain missmatches " + "the field's dtype.")) + elif d.datamodel != self.datamodel: + raise AttributeError(about._errors.cstring( + "ERROR: The datamodel of a space in the domain missmatches " + "the field's datamodel.")) + return domain + + def check_codomain(self, domain, codomain): + if codomain is None: + return False + if domain.shape == codomain.shape: + return np.all(map((lambda d, c: d._check_codomain(c)), domain, + codomain)) + else: + return False + + def get_codomain(self, domain): + if domain.shape == (1,): + return np.array(domain[0].get_codomain()) + else: + # TODO implement for multiple domain get_codomain need + # calc_transform + return np.empty((0,)) + def __len__(self): return int(self.get_dim(split=True)[0]) @@ -2010,29 +2021,127 @@ class field(object): def get_ishape(self): return self.ishape + def get_global_shape(self): + global_shape = np.sum([space.get_shape() for space in self.domain]) + if isinstance(global_shape, tuple): + return global_shape + else: + return () + def _map(self, function, *args): - return utilities.field_map(self.ishape, function, *args) + return utilities.field_map(self.get_global_shape(), function, *args) - def cast(self, x=None, ishape=None): - if ishape is None: - ishape = self.ishape - casted_x = self._cast_to_ishape(x, ishape=ishape) - if ishape == (): + def cast(self, x=None, dtype=None): + if dtype is not None: + dtype = np.dtype(dtype) + if dtype is None: + dtype = self.dtype + casted_x = self._cast_to_shape(x) + if self.get_global_shape() == (): return self.domain.cast(casted_x) else: return self._map(lambda z: self.domain.cast(z), casted_x) - def _cast_to_ishape(self, x, ishape=None): - if ishape is None: - ishape = self.ishape + def _cast_to_d2o(self, x, dtype=None, **kwargs): + """ + Computes valid field values from a given object, trying + to translate the given data into a valid form. Thereby it is as + benevolent as possible. + Parameters + ---------- + x : {float, numpy.ndarray, nifty.field} + Object to be transformed into an array of valid field values. + + Returns + ------- + x : numpy.ndarray, distributed_data_object + Array containing the field values, which are compatible to the + space. + + Other parameters + ---------------- + verbose : bool, *optional* + Whether the method should raise a warning if information is + lost during casting (default: False). + """ + + if dtype is None: + dtype = self.dtype + + # Case 1: x is a distributed_data_object + if isinstance(x, distributed_data_object): + to_copy = False + + # Check the shape + if np.any(np.array(x.shape) != np.array(self.get_shape())): + # Check if at least the number of degrees of freedom is equal + if x.get_dim() == self.get_dim(): + try: + temp = x.copy_empty(global_shape=self.get_shape()) + temp.set_local_data(x.get_local_data(), copy=False) + except: + # If the number of dof is equal or 1, use np.reshape... + about.warnings.cflush( + "WARNING: Trying to reshape the data. This " + + "operation is expensive as it consolidates the " + + "full data!\n") + temp = x.get_full_data() + temp = np.reshape(temp, self.get_shape()) + # ... and cast again + return self._cast_to_d2o(temp, + dtype=dtype, + **kwargs) + + else: + raise ValueError(about._errors.cstring( + "ERROR: Data has incompatible shape!")) + + # Check the dtype + if x.dtype != dtype: + if x.dtype > dtype: + about.warnings.cflush( + "WARNING: Datatypes are of conflicting precision " + + "(own: " + str(dtype) + " <> foreign: " + + str(x.dtype) + ") and will be casted! Potential " + + "loss of precision!\n") + to_copy = True + + # Check the distribution_strategy + if x.distribution_strategy != self.datamodel: + to_copy = True + + if to_copy: + temp = x.copy_empty(dtype=dtype, + distribution_strategy=self.datamodel) + temp.set_data(to_key=(slice(None),), + data=x, + from_key=(slice(None),)) + temp.hermitian = x.hermitian + x = temp + + return x + + # Case 2: x is something else + # Use general d2o casting + else: + x = distributed_data_object(x, + global_shape=self.get_shape(), + dtype=dtype, + distribution_strategy=self.datamodel) + # Cast the d2o + return self.cast(x, dtype=dtype) + + def _cast_to_shape(self, x): if isinstance(x, field): x = x.get_val() - if ishape == (): + + global_shape = self.get_global_shape() + if global_shape == (): casted_x = self._cast_to_scalar_helper(x) else: - casted_x = self._cast_to_tensor_helper(x, ishape) + casted_x = self._cast_to_tensor_helper(x, shape=global_shape) return casted_x def _cast_to_scalar_helper(self, x): @@ -2068,26 +2177,26 @@ class field(object): # In all other cases, cast x directly return x - def _cast_to_tensor_helper(self, x, ishape=None): - if ishape is None: - ishape = self.ishape + def _cast_to_tensor_helper(self, x, shape=None): + if shape is None: + shape = self.get_global_shape() # Check if x is a container of proper length # containing something which will then checked by the domain-space x_shape = np.shape(x) - self_shape = self.domain.get_shape() + self_shape = self.get_global_shape() try: container_Q = (x.dtype.type == np.object_) except(AttributeError): container_Q = False if container_Q: - if x_shape == ishape: + if x_shape == shape: return x - elif x_shape == ishape[:len(x_shape)]: + elif x_shape == shape[:len(x_shape)]: return x.reshape(x_shape + - (1,) * (len(ishape) - len(x_shape))) + (1,) * (len(shape) - len(x_shape))) # Slow track: x could be a pure ndarray @@ -2095,11 +2204,11 @@ class field(object): # 1: There are cases where np.shape will only find the container # although it was no np.object array; e.g. for [a,1]. # 2: The overall shape is already the right one - if x_shape == ishape or x_shape == (ishape + self_shape): + if x_shape == shape or x_shape == (shape + self_shape): # Iterate over the outermost dimension and cast the inner spaces - result = np.empty(ishape, dtype=np.object) - for i in xrange(np.prod(ishape)): - ii = np.unravel_index(i, ishape) + result = np.empty(shape, dtype=np.object) + for i in xrange(np.prod(shape)): + ii = np.unravel_index(i, shape) try: result[ii] = x[ii] except(TypeError): @@ -2112,16 +2221,16 @@ class field(object): # Check if the input has shape (1, self.domain.shape) # Iterate over the outermost dimension and cast the inner spaces elif x_shape == ((1,) + self_shape): - result = np.empty(ishape, dtype=np.object) - for i in xrange(np.prod(ishape)): - ii = np.unravel_index(i, ishape) + result = np.empty(shape, dtype=np.object) + for i in xrange(np.prod(shape)): + ii = np.unravel_index(i, shape) result[ii] = x[0] # Case 4: fallback: try to cast x with self.domain else: # Iterate over the outermost dimension and cast the inner spaces - result = np.empty(ishape, dtype=np.object) - for i in xrange(np.prod(ishape)): - ii = np.unravel_index(i, ishape) + result = np.empty(shape, dtype=np.object) + for i in xrange(np.prod(shape)): + ii = np.unravel_index(i, shape) result[ii] = x return result @@ -2903,4 +3012,4 @@ class field(object): class EmptyField(field): def __init__(self): - pass \ No newline at end of file + pass diff --git a/test/test_nifty_field.py b/test/test_nifty_field.py index 7779dbf19a4929d478039de12a18f22f3ba080b9..2961652af50b3fe56d044cb6d08d579f7cfc57a5 100644 --- a/test/test_nifty_field.py +++ b/test/test_nifty_field.py @@ -56,11 +56,11 @@ all_hp_datatypes = [np.dtype('float64')] ############################################################################### DATAMODELS = {} -DATAMODELS['point_space'] = ['np'] + POINT_DISTRIBUTION_STRATEGIES -DATAMODELS['rg_space'] = ['np'] + RG_DISTRIBUTION_STRATEGIES -DATAMODELS['lm_space'] = ['np'] + LM_DISTRIBUTION_STRATEGIES -DATAMODELS['gl_space'] = ['np'] + GL_DISTRIBUTION_STRATEGIES -DATAMODELS['hp_space'] = ['np'] + HP_DISTRIBUTION_STRATEGIES +DATAMODELS['point_space'] = POINT_DISTRIBUTION_STRATEGIES +DATAMODELS['rg_space'] = RG_DISTRIBUTION_STRATEGIES +DATAMODELS['lm_space'] = LM_DISTRIBUTION_STRATEGIES +DATAMODELS['gl_space'] = GL_DISTRIBUTION_STRATEGIES +DATAMODELS['hp_space'] = HP_DISTRIBUTION_STRATEGIES ############################################################################### @@ -110,10 +110,9 @@ class Test_field_init(unittest.TestCase): @parameterized.expand(space_list) def test_successfull_init_and_attributes(self, s): - s = s[0] - f = field(s) - assert(f.domain is s) - assert(s.check_codomain(f.codomain)) + f = field(domain=np.array([s]), dtype=s.dtype) + assert(f.domain[0] is s) + assert(s.check_codomain(f.codomain[0]))