Commit 988981a0 by Ultima

### Some bugfixes.

```Added direct_dot to nifty_simple_math.py.
Begun with direct probing implementation.```
parent 148e6fbc
PKG-INFO 0 → 100644
 Metadata-Version: 1.0 Name: ift_nifty Version: 1.0.6 Summary: Numerical Information Field Theory Home-page: http://www.mpa-garching.mpg.de/ift/nifty/ Author: Theo Steininger Author-email: theos@mpa-garching.mpg.de License: GPLv3 Description: UNKNOWN Platform: UNKNOWN
 ... ... @@ -214,9 +214,9 @@ class lm_space(point_space): return self.paradict['mmax'] def shape(self): mmax = self.paradict('mmax') lmax = self.paradict('lmax') return np.array([(mmax+1)*(lmax+1)-(lmax+1)*(mmax//2)], dtype=int) mmax = self.paradict['mmax'] lmax = self.paradict['lmax'] return np.array([(mmax+1)*(lmax+1)-((lmax+1)*lmax)//2], dtype=int) def dim(self,split=False): """ ... ... @@ -654,6 +654,7 @@ class lm_space(point_space): else: return self._dotlm(x,y) ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ def calc_transform(self,x,codomain=None,**kwargs): ... ...
 ... ... @@ -2813,13 +2813,7 @@ class field(object): New field values either as a constant or an arbitrary array. """ ''' if(new_val is None): self.val = np.zeros(self.dim(split=True),dtype=self.domain.datatype,order='C') else: self.val = self.domain.enforce_values(new_val,extend=True) ''' self.val = self.domain.cast(new_val) self.val = new_val return self.val def get_val(self): ... ... @@ -2913,17 +2907,20 @@ class field(object): elif isinstance(x, field): ## if x lives in the cospace, transform it an make a ## recursive call if self.domain.fourier != x.domain.fourier: return self.dot(x = x.transform()) else: try: if self.domain.fourier != x.domain.fourier: return self.dot(x = x.transform()) except(AttributeError): pass ## whether the domain matches exactly or not: ## extract the data from x and try to dot with this return self.dot(x = x.get_val()) return self.dot(x = x.get_val()) ## Case 3: x is something else else: ## Cast the input in order to cure datatype and shape differences casted_x = self.cast(x) casted_x = self.domain.cast(x) ## Compute the dot respecting the fact of discrete/continous spaces if self.domain.discrete == True: return self.domain.calc_dot(self.get_val(), casted_x) ... ...
 ... ... @@ -411,7 +411,7 @@ class distributed_data_object(object): found_boolean = (key.dtype == np.bool) else: found = 'other' ## TODO: transfer this into distributor: if (found == 'ndarray' or found == 'd2o') and found_boolean == True: ## extract the data of local relevance local_bool_array = self.distributor.extract_local_data(key) ... ... @@ -1541,7 +1541,12 @@ class dtype_converter: self._to_np_dict = dict(to_np_pre_dict) def dictionize_np(self, x): return frozenset(x.__dict__.items()) dic = x.__dict__.items() if x is np.float: dic[24] = 0 dic[29] = 0 dic[37] = 0 return frozenset(dic) def dictionize_mpi(self, x): return x.name ... ...
 ... ... @@ -483,5 +483,30 @@ def conjugate(x): """ return _math_helper(x, np.conjugate) def direct_dot(x, y): ## the input could be fields. Try to extract the data try: x = x.get_val() except(AttributeError): pass ## try to make a direct vdot try: return x.vdot(y) except(AttributeError): pass try: return y.vdot(x) except(AttributeError): pass ## fallback to numpy return np.vdot(x, y) ##--------------------------------- \ No newline at end of file
 ... ... @@ -29,6 +29,7 @@ from nifty.nifty_core import space, \ from nifty_minimization import conjugate_gradient from nifty_probing import trace_probing, \ diagonal_probing from nifty_mpi_probing import prober ##============================================================================= ... ... @@ -165,6 +166,33 @@ class operator(object): self.para = para @property def val(self): return self.__val @val.setter def val(self, x): self.__val = self.domain.cast(x) def set_val(self, new_val): """ Resets the field values. Parameters ---------- new_val : {scalar, ndarray} New field values either as a constant or an arbitrary array. """ self.val = new_val return self.val def get_val(self): return self.val ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ def nrow(self): ... ... @@ -301,7 +329,7 @@ class operator(object): if self.domain == self.target != x.domain: x_ = x_.transform(target=x.domain) if x_.domain == x.domain and (x_.target is not x.target): x_.set_target(newtarget = x.target) x_.set_target(new_target = x.target) return x_ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ... ... @@ -1190,27 +1218,9 @@ class diagonal_operator(operator): self.domain = domain self.target = self.domain ## Set the diag-val self.val = self.domain.cast(diag) ## Weight if necessary if self.domain.discrete == False and bare == True: self.val = self.domain.calc_weight(self.val, power = 1) ## Check complexity attributes if self.domain.calc_real_Q(self.val) == True: self.sym = True else: self.sym = False ## Check if unitary, i.e. identity if (self.val == 1).all() == True: self.uni = True else: self.uni = False self.imp = True self.set_diag(new_diag = diag) """ if(domain is None)and(isinstance(diag,field)): ... ... @@ -1243,13 +1253,13 @@ class diagonal_operator(operator): """ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ def set_diag(self,newdiag,bare=False): def set_diag(self, new_diag, bare=False): """ Sets the diagonal of the diagonal operator Parameters ---------- newdiag : {scalar, ndarray, field} new_diag : {scalar, ndarray, field} The new diagonal entries of the operator. For a scalar, a constant diagonal is defined having the value provided. If no domain is given, diag must be a field. ... ... @@ -1263,7 +1273,28 @@ class diagonal_operator(operator): ------- None """ newdiag = self.domain.enforce_values(newdiag,extend=True) ## Set the diag-val self.val = self.domain.cast(new_diag) ## Weight if necessary if self.domain.discrete == False and bare == True: self.val = self.domain.calc_weight(self.val, power = 1) ## Check complexity attributes if self.domain.calc_real_Q(self.val) == True: self.sym = True else: self.sym = False ## Check if unitary, i.e. identity if (self.val == 1).all() == True: self.uni = True else: self.uni = False """ newdiag = self.domain.enforce_values(newdiag, extend=True) ## weight if ... if(not self.domain.discrete)and(bare): newdiag = self.domain.calc_weight(newdiag,power=1) ... ... @@ -1281,21 +1312,50 @@ class diagonal_operator(operator): self.uni = True else: self.uni = False """ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ def _multiply(self,x,**kwargs): ## > applies the operator to a given field def _multiply(self, x, **kwargs): ## > applies the operator to a given field y = x.copy(domain = self.target, target = self.target.get_codomain()) y *= self.get_val() return y """ x_ = field(self.target,val=None,target=x.target) x_.val = x.val*self.val ## bypasses self.domain.enforce_values return x_ """ def _adjoint_multiply(self,x,**kwargs): ## > applies the adjoint operator to a given field def _adjoint_multiply(self, x, **kwargs): ## > applies the adjoint operator to a given field y = x.copy(domain = self.domain, target = self.domain.get_codomain()) y *= self.get_val().conjugate() return y """ x_ = field(self.domain,val=None,target=x.target) x_.val = x.val*np.conjugate(self.val) ## bypasses self.domain.enforce_values return x_ """ def _inverse_multiply(self,x,pseudo=False,**kwargs): ## > applies the inverse operator to a given field if(np.any(self.val==0)): def _inverse_multiply(self, x, pseudo=False, **kwargs): ## > applies the inverse operator to a given field y = x.copy(domain = self.domain, target = self.domain.get_codomain()) if (self.get_val() == 0).any(): if pseudo == False: raise AttributeError(about._errors.cstring( "ERROR: singular operator.")) else: y /= self.get_val() y[y == np.nan] = 0 y[y == np.inf] = 0 else: y /= self.get_val() return y """ if (np.any(self.val==0)): if(pseudo): x_ = field(self.domain,val=None,target=x.target) x_.val = np.ma.filled(x.val/np.ma.masked_where(self.val==0,self.val,copy=False),fill_value=0) ## bypasses self.domain.enforce_values ... ... @@ -1306,8 +1366,25 @@ class diagonal_operator(operator): x_ = field(self.domain,val=None,target=x.target) x_.val = x.val/self.val ## bypasses self.domain.enforce_values return x_ def _adjoint_inverse_multiply(self,x,pseudo=False,**kwargs): ## > applies the inverse adjoint operator to a given field """ def _adjoint_inverse_multiply(self, x, pseudo=False, **kwargs): ## > applies the inverse adjoint operator to a given field y = x.copy(domain = self.target, target = self.target.get_codomain()) if (self.get_val() == 0).any(): if pseudo == False: raise AttributeError(about._errors.cstring( "ERROR: singular operator.")) else: y /= self.get_val().conjugate() y[y == np.nan] = 0 y[y == np.inf] = 0 else: y /= self.get_val().conjugate() return y """ if(np.any(self.val==0)): if(pseudo): x_ = field(self.domain,val=None,target=x.target) ... ... @@ -1319,8 +1396,11 @@ class diagonal_operator(operator): x_ = field(self.target,val=None,target=x.target) x_.val = x.val/np.conjugate(self.val) ## bypasses self.domain.enforce_values return x_ def _inverse_adjoint_multiply(self,x,pseudo=False,**kwargs): ## > applies the adjoint inverse operator to a given field """ def _inverse_adjoint_multiply(self, x, pseudo=False, **kwargs): ## > applies the adjoint inverse operator to a given field return self._adjoint_inverse_multiply(x, pseudo = pseudo, **kwargs) """ if(np.any(self.val==0)): if(pseudo): x_ = field(self.domain,val=None,target=x.target) ... ... @@ -1333,7 +1413,7 @@ class diagonal_operator(operator): x_ = field(self.target,val=None,target=x.target) x_.val = x.val*np.conjugate(1/self.val) ## bypasses self.domain.enforce_values return x_ """ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ def tr(self,domain=None,**kwargs): ... ... @@ -2400,13 +2480,13 @@ class projection_operator(operator): if(x_.domain.nest[-1]!=x.domain): x_ = x_.transform(target=nested_space([point_space(len(self.ind),datatype=x.domain.datatype),x.domain]),overwrite=False) ## ... domain if(x_.target.nest[-1]!=x.target): x_.set_target(newtarget=nested_space([point_space(len(self.ind),datatype=x.target.datatype),x.target])) ## ... codomain x_.set_target(new_target=nested_space([point_space(len(self.ind),datatype=x.target.datatype),x.target])) ## ... codomain else: ## repair ... if(x_.domain!=x.domain): x_ = x_.transform(target=x.domain,overwrite=False) ## ... domain if(x_.target!=x.target): x_.set_target(newtarget=x.target) ## ... codomain x_.set_target(new_target=x.target) ## ... codomain return x_ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ... ... @@ -3071,7 +3151,7 @@ class response_operator(operator): if(self.domain==self.target!=x.domain): x_ = x_.transform(target=x.domain,overwrite=False) ## ... domain if(x_.domain==x.domain)and(x_.target!=x.target): x_.set_target(newtarget=x.target) ## ... codomain x_.set_target(new_target=x.target) ## ... codomain return x_ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ... ... @@ -3495,7 +3575,7 @@ class propagator_operator(operator): if(in_codomain)and(x.domain!=self.codomain): x_ = x_.transform(target=x.domain,overwrite=False) ## ... domain if(x_.target!=x.target): x_.set_target(newtarget=x.target) ## ... codomain x_.set_target(new_target=x.target) ## ... codomain return x_ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment