From db57b2dec57f697eb4481dac1ed799e72810f88f Mon Sep 17 00:00:00 2001 From: Marco Selig <mselig@ncg-02.MPA-Garching.MPG.DE> Date: Wed, 11 Dec 2013 14:25:07 +0100 Subject: [PATCH] minor fixes; 'explicify' simplified. --- nifty_core.py | 22 +++++----- nifty_explicit.py | 102 ++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 96 insertions(+), 28 deletions(-) diff --git a/nifty_core.py b/nifty_core.py index 133683458..c644187bc 100644 --- a/nifty_core.py +++ b/nifty_core.py @@ -774,7 +774,7 @@ class random(object): """ size = np.prod(shape,axis=0,dtype=np.int,out=None) - if(datatype in [np.complex64,np.complex128]): + if(issubclass(datatype,np.complexfloating)): x = np.array([1+0j,0+1j,-1+0j,0-1j],dtype=datatype)[np.random.randint(4,high=None,size=size)] else: x = 2*np.random.randint(2,high=None,size=size)-1 @@ -816,7 +816,7 @@ class random(object): """ size = np.prod(shape,axis=0,dtype=np.int,out=None) - if(datatype in [np.complex64,np.complex128]): + if(issubclass(datatype,np.complexfloating)): x = np.empty(size,dtype=datatype,order='C') x.real = np.random.normal(loc=0,scale=np.sqrt(0.5),size=size) x.imag = np.random.normal(loc=0,scale=np.sqrt(0.5),size=size) @@ -1941,7 +1941,7 @@ class point_space(space): Number of degrees of freedom of the space. """ ## dof ~ dim - if(self.datatype in [np.complex64,np.complex128]): + if(issubclass(self.datatype,np.complexfloating)): return 2*self.para[0] else: return self.para[0] @@ -10609,7 +10609,7 @@ class probing(object): ## check codomain if(target is None): - target = domain.get_codomain() + target = self.domain.get_codomain() else: self.domain.check_codomain(target) ## a bit pointless self.target = target @@ -11087,7 +11087,7 @@ class trace_probing(probing): ## check codomain if(target is None): - target = domain.get_codomain() + target = self.domain.get_codomain() else: self.domain.check_codomain(target) ## a bit pointless self.target = target @@ -11171,7 +11171,7 @@ class trace_probing(probing): else: about.infos.cflush("\n") - if(self.domain.datatype in [np.complex64,np.complex128]): + if(issubclass(self.domain.datatype,np.complexfloating)): summa = np.real(summa) final = summa/num @@ -11219,7 +11219,7 @@ class trace_probing(probing): ## define random seed seed = np.random.randint(10**8,high=None,size=self.nrun) ## define shared objects - if(self.domain.datatype in [np.complex64,np.complex128]): + if(issubclass(self.domain.datatype,np.complexfloating)): _sum = (mv('d',0,lock=True),mv('d',0,lock=True)) ## tuple(real,imag) else: _sum = mv('d',0,lock=True) @@ -11246,7 +11246,7 @@ class trace_probing(probing): pool.join() raise Exception(about._errors.cstring("ERROR: unknown. NOTE: pool terminated.")) ## traceback by looping ## evaluate - if(self.domain.datatype in [np.complex64,np.complex128]): + if(issubclass(self.domain.datatype,np.complexfloating)): _sum = np.complex(_sum[0].value,_sum[1].value) else: _sum = _sum.value @@ -11478,7 +11478,7 @@ class diagonal_probing(probing): ## check codomain if(target is None): - target = domain.get_codomain() + target = self.domain.get_codomain() else: self.domain.check_codomain(target) ## a bit pointless self.target = target @@ -11646,7 +11646,7 @@ class diagonal_probing(probing): ## define random seed seed = np.random.randint(10**8,high=None,size=self.nrun) ## define shared objects - if(self.domain.datatype in [np.complex64,np.complex128]): + if(issubclass(self.domain.datatype,np.complexfloating)): _sum = (ma('d',np.zeros(self.domain.dim(split=False),dtype=np.float64,order='C'),lock=True),ma('d',np.zeros(self.domain.dim(split=False),dtype=np.float64,order='C'),lock=True)) ## tuple(real,imag) else: _sum = ma('d',np.zeros(self.domain.dim(split=False),dtype=np.float64,order='C'),lock=True) @@ -11673,7 +11673,7 @@ class diagonal_probing(probing): pool.join() raise Exception(about._errors.cstring("ERROR: unknown. NOTE: pool terminated.")) ## traceback by looping ## evaluate - if(self.domain.datatype in [np.complex64,np.complex128]): + if(issubclass(self.domain.datatype,np.complexfloating)): _sum = (np.array(_sum[0][:])+np.array(_sum[1][:])*1j).reshape(self.domain.dim(split=True)) ## comlpex array else: _sum = np.array(_sum[:]).reshape(self.domain.dim(split=True)) diff --git a/nifty_explicit.py b/nifty_explicit.py index d79daf906..7e8d4a400 100644 --- a/nifty_explicit.py +++ b/nifty_explicit.py @@ -107,8 +107,8 @@ class explicit_operator(operator): if(len(bare)!=2): raise ValueError(about._errors.cstring("ERROR: invalid input.")) else: - val = self._calc_weight_rows(val,-bool(bare[0])) - val = self._calc_weight_cols(val,-bool(bare[1])) + val = self._calc_weight_rows(val,power=-int(not bare[0])) + val = self._calc_weight_cols(val,power=-int(not bare[1])) elif(not bare): val = self._calc_weight_rows(val,-1) if(purelyreal): @@ -156,6 +156,38 @@ class explicit_operator(operator): ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + def cast_domain(self,newdomain): + """ + TODO: documentation + + """ + if(not isinstance(newdomain,space)): + raise TypeError(about._errors.cstring("ERROR: invalid input.")) + elif(newdomain.dim(split=False)!=self.domain.dim(split=False)): + raise ValueError(about._errors.cstring("ERROR: dimension mismatch ( "+str(newdomain.dim(split=False))+" <> "+str(self.domain.dim(split=False))+" ).")) + self.domain = newdomain + + def cast_target(self,newtarget): + """ + TODO: documentation + + """ + if(not isinstance(newtarget,space)): + raise TypeError(about._errors.cstring("ERROR: invalid input.")) + elif(newtarget.dim(split=False)!=self.target.dim(split=False)): + raise ValueError(about._errors.cstring("ERROR: dimension mismatch ( "+str(newtarget.dim(split=False))+" <> "+str(self.target.dim(split=False))+" ).")) + self.target = newtarget + + def cast_spaces(self,newdomain,newtarget): + """ + TODO: documentation + + """ + self.cast_domain(newdomain) + self.cast_target(newtarget) + + ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + def set_matrix(self,newmatrix,bare=True,sym=None,uni=None): """ TODO: documentation @@ -182,8 +214,8 @@ class explicit_operator(operator): if(len(bare)!=2): raise ValueError(about._errors.cstring("ERROR: invalid input.")) else: - val = self._calc_weight_rows(val,-bool(bare[0])) - val = self._calc_weight_cols(val,-bool(bare[1])) + val = self._calc_weight_rows(val,power=-int(not bare[0])) + val = self._calc_weight_cols(val,power=-int(not bare[1])) elif(not bare): val = self._calc_weight_rows(val,-1) if(purelyreal): @@ -209,9 +241,9 @@ class explicit_operator(operator): if(len(bare)!=2): raise ValueError(about._errors.cstring("ERROR: invalid input.")) else: - return self.weight(rowpower=bool(bare[0]),colpower=bool(bare[1]),overwrite=False) + return self.weight(rowpower=int(not bare[0]),colpower=int(not bare[1]),overwrite=False) elif(not bare): - return self.weight(rowpower=bool(bare),colpower=0,overwrite=False) + return self.weight(rowpower=int(not bare),colpower=0,overwrite=False) ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ @@ -222,10 +254,16 @@ class explicit_operator(operator): return self.target.calc_weight(x,power=power).flatten(order='C') def _calc_weight_rows(self,X,power=1): ## > weight all rows - return np.apply_along_axis(self._calc_weight_row,1,X,power) + if(np.any(np.iscomplex(X)))and(not issubclass(self.domain.datatype,np.complexfloating)): + return (np.apply_along_axis(self._calc_weight_row,1,np.real(X),power)+np.apply_along_axis(self._calc_weight_row,1,np.imag(X),power)*1j) + else: + return np.apply_along_axis(self._calc_weight_row,1,X,power) def _calc_weight_cols(self,X,power=1): ## > weight all columns - return np.apply_along_axis(self._calc_weight_col,0,X,power) + if(np.any(np.iscomplex(X)))and(not issubclass(self.target.datatype,np.complexfloating)): + return (np.apply_along_axis(self._calc_weight_col,0,np.real(X),power)+np.apply_along_axis(self._calc_weight_col,0,np.imag(X),power)*1j) + else: + return np.apply_along_axis(self._calc_weight_col,0,X,power) def weight(self,rowpower=0,colpower=0,overwrite=False): """ @@ -1448,7 +1486,7 @@ class explicit_probing(probing): # Keyword arguments passed to `function` in each call. # # """ - def __init__(self,op=None,function=None,domain=None,codomain=None,target=None,ncpu=2,nper=None,**quargs): + def __init__(self,op=None,function=None,domain=None,codomain=None,target=None,ncpu=2,nper=1,**quargs): """ TODO: documentation @@ -1489,7 +1527,7 @@ class explicit_probing(probing): # If on the other hand `nper=1`, then for each evaluation a worker will # be created. In this case all cpus will work until nrun probes have # been evaluated. -# It is recommended to leave `nper` as the default value. (default: 8) +# It is recommended to leave `nper` as the default value. (default: 1) # # """ if(op is None): @@ -1546,12 +1584,12 @@ class explicit_probing(probing): self.domain = domain self.codomain = codomain + ## check target if(target is None): - self.target = domain.get_codomain() + target = self.domain.get_codomain() else: - ## check codomain self.domain.check_codomain(target) ## a bit pointless - self.target = target + self.target = target ## check shape if(self.domain.dim(split=False)*self.codomain.dim(split=False)>1048576): @@ -1717,7 +1755,7 @@ class explicit_probing(probing): ## define weighted canonical base base = self.domain.calc_weight(self.domain.enforce_values(1,extend=True),power=-1).flatten(order='C') ## define shared objects - if(self.codomain.datatype in [np.complex64,np.complex128]): + if(issubclass(self.codomain.datatype,np.complexfloating)): _mat = (ma('d',np.empty(self.nrun*self.codomain.dim(split=False),dtype=np.float64,order='C'),lock=True),ma('d',np.empty(self.nrun*self.codomain.dim(split=False),dtype=np.float64,order='C'),lock=True)) ## tuple(real,imag) else: _mat = ma('d',np.empty(self.nrun*self.codomain.dim(split=False),dtype=np.float64,order='C'),lock=True) @@ -1740,7 +1778,7 @@ class explicit_probing(probing): pool.join() raise Exception(about._errors.cstring("ERROR: unknown. NOTE: pool terminated.")) ## traceback by looping ## evaluate - if(self.domain.datatype in [np.complex64,np.complex128]): + if(issubclass(self.codomain.datatype,np.complexfloating)): _mat = (np.array(_mat[0][:])+np.array(_mat[1][:])*1j).reshape((self.nrun,self.codomain.dim(split=False))) ## comlpex array else: _mat = np.array(_mat[:]).reshape((self.nrun,self.codomain.dim(split=False))) @@ -1805,12 +1843,42 @@ class explicit_probing(probing): ##----------------------------------------------------------------------------- -def explicify(operator,loop=False,**kwargs): +def explicify(operator,newdomain=None,newtarget=None,ncpu=2,nper=1,loop=False,**quargs): """ TODO: documentation """ - return explicit_probing(op=operator,**kwargs)(loop=loop) + return explicit_probing(op=operator,function=operator.times,domain=newdomain,codomain=newtarget,target=operator.domain,ncpu=ncpu,nper=nper,**quargs)(loop=loop) ##----------------------------------------------------------------------------- + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + -- GitLab