Commit 988981a0 by Ultima

### Some bugfixes.

```Added direct_dot to nifty_simple_math.py.
Begun with direct probing implementation.```
parent 148e6fbc
PKG-INFO 0 → 100644
 Metadata-Version: 1.0 Name: ift_nifty Version: 1.0.6 Summary: Numerical Information Field Theory Home-page: http://www.mpa-garching.mpg.de/ift/nifty/ Author: Theo Steininger Author-email: theos@mpa-garching.mpg.de License: GPLv3 Description: UNKNOWN Platform: UNKNOWN
 ... @@ -214,9 +214,9 @@ class lm_space(point_space): ... @@ -214,9 +214,9 @@ class lm_space(point_space): return self.paradict['mmax'] return self.paradict['mmax'] def shape(self): def shape(self): mmax = self.paradict('mmax') mmax = self.paradict['mmax'] lmax = self.paradict('lmax') lmax = self.paradict['lmax'] return np.array([(mmax+1)*(lmax+1)-(lmax+1)*(mmax//2)], dtype=int) return np.array([(mmax+1)*(lmax+1)-((lmax+1)*lmax)//2], dtype=int) def dim(self,split=False): def dim(self,split=False): """ """ ... @@ -654,6 +654,7 @@ class lm_space(point_space): ... @@ -654,6 +654,7 @@ class lm_space(point_space): else: else: return self._dotlm(x,y) return self._dotlm(x,y) ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ def calc_transform(self,x,codomain=None,**kwargs): def calc_transform(self,x,codomain=None,**kwargs): ... ...
 ... @@ -2813,13 +2813,7 @@ class field(object): ... @@ -2813,13 +2813,7 @@ class field(object): New field values either as a constant or an arbitrary array. New field values either as a constant or an arbitrary array. """ """ ''' self.val = new_val if(new_val is None): self.val = np.zeros(self.dim(split=True),dtype=self.domain.datatype,order='C') else: self.val = self.domain.enforce_values(new_val,extend=True) ''' self.val = self.domain.cast(new_val) return self.val return self.val def get_val(self): def get_val(self): ... @@ -2913,17 +2907,20 @@ class field(object): ... @@ -2913,17 +2907,20 @@ class field(object): elif isinstance(x, field): elif isinstance(x, field): ## if x lives in the cospace, transform it an make a ## if x lives in the cospace, transform it an make a ## recursive call ## recursive call if self.domain.fourier != x.domain.fourier: try: return self.dot(x = x.transform()) if self.domain.fourier != x.domain.fourier: else: return self.dot(x = x.transform()) except(AttributeError): pass ## whether the domain matches exactly or not: ## whether the domain matches exactly or not: ## extract the data from x and try to dot with this ## extract the data from x and try to dot with this return self.dot(x = x.get_val()) return self.dot(x = x.get_val()) ## Case 3: x is something else ## Case 3: x is something else else: else: ## Cast the input in order to cure datatype and shape differences ## Cast the input in order to cure datatype and shape differences casted_x = self.cast(x) casted_x = self.domain.cast(x) ## Compute the dot respecting the fact of discrete/continous spaces ## Compute the dot respecting the fact of discrete/continous spaces if self.domain.discrete == True: if self.domain.discrete == True: return self.domain.calc_dot(self.get_val(), casted_x) return self.domain.calc_dot(self.get_val(), casted_x) ... ...
 ... @@ -411,7 +411,7 @@ class distributed_data_object(object): ... @@ -411,7 +411,7 @@ class distributed_data_object(object): found_boolean = (key.dtype == np.bool) found_boolean = (key.dtype == np.bool) else: else: found = 'other' found = 'other' ## TODO: transfer this into distributor: if (found == 'ndarray' or found == 'd2o') and found_boolean == True: if (found == 'ndarray' or found == 'd2o') and found_boolean == True: ## extract the data of local relevance ## extract the data of local relevance local_bool_array = self.distributor.extract_local_data(key) local_bool_array = self.distributor.extract_local_data(key) ... @@ -1541,7 +1541,12 @@ class dtype_converter: ... @@ -1541,7 +1541,12 @@ class dtype_converter: self._to_np_dict = dict(to_np_pre_dict) self._to_np_dict = dict(to_np_pre_dict) def dictionize_np(self, x): def dictionize_np(self, x): return frozenset(x.__dict__.items()) dic = x.__dict__.items() if x is np.float: dic[24] = 0 dic[29] = 0 dic[37] = 0 return frozenset(dic) def dictionize_mpi(self, x): def dictionize_mpi(self, x): return x.name return x.name ... ...
 ... @@ -483,5 +483,30 @@ def conjugate(x): ... @@ -483,5 +483,30 @@ def conjugate(x): """ """ return _math_helper(x, np.conjugate) return _math_helper(x, np.conjugate) def direct_dot(x, y): ## the input could be fields. Try to extract the data try: x = x.get_val() except(AttributeError): pass ## try to make a direct vdot try: return x.vdot(y) except(AttributeError): pass try: return y.vdot(x) except(AttributeError): pass ## fallback to numpy return np.vdot(x, y) ##--------------------------------- ##--------------------------------- \ No newline at end of file