Commit db57b2de authored by Marco Selig's avatar Marco Selig

minor fixes; 'explicify' simplified.

parent 47b6395f
......@@ -774,7 +774,7 @@ class random(object):
"""
size = np.prod(shape,axis=0,dtype=np.int,out=None)
if(datatype in [np.complex64,np.complex128]):
if(issubclass(datatype,np.complexfloating)):
x = np.array([1+0j,0+1j,-1+0j,0-1j],dtype=datatype)[np.random.randint(4,high=None,size=size)]
else:
x = 2*np.random.randint(2,high=None,size=size)-1
......@@ -816,7 +816,7 @@ class random(object):
"""
size = np.prod(shape,axis=0,dtype=np.int,out=None)
if(datatype in [np.complex64,np.complex128]):
if(issubclass(datatype,np.complexfloating)):
x = np.empty(size,dtype=datatype,order='C')
x.real = np.random.normal(loc=0,scale=np.sqrt(0.5),size=size)
x.imag = np.random.normal(loc=0,scale=np.sqrt(0.5),size=size)
......@@ -1941,7 +1941,7 @@ class point_space(space):
Number of degrees of freedom of the space.
"""
## dof ~ dim
if(self.datatype in [np.complex64,np.complex128]):
if(issubclass(self.datatype,np.complexfloating)):
return 2*self.para[0]
else:
return self.para[0]
......@@ -10609,7 +10609,7 @@ class probing(object):
## check codomain
if(target is None):
target = domain.get_codomain()
target = self.domain.get_codomain()
else:
self.domain.check_codomain(target) ## a bit pointless
self.target = target
......@@ -11087,7 +11087,7 @@ class trace_probing(probing):
## check codomain
if(target is None):
target = domain.get_codomain()
target = self.domain.get_codomain()
else:
self.domain.check_codomain(target) ## a bit pointless
self.target = target
......@@ -11171,7 +11171,7 @@ class trace_probing(probing):
else:
about.infos.cflush("\n")
if(self.domain.datatype in [np.complex64,np.complex128]):
if(issubclass(self.domain.datatype,np.complexfloating)):
summa = np.real(summa)
final = summa/num
......@@ -11219,7 +11219,7 @@ class trace_probing(probing):
## define random seed
seed = np.random.randint(10**8,high=None,size=self.nrun)
## define shared objects
if(self.domain.datatype in [np.complex64,np.complex128]):
if(issubclass(self.domain.datatype,np.complexfloating)):
_sum = (mv('d',0,lock=True),mv('d',0,lock=True)) ## tuple(real,imag)
else:
_sum = mv('d',0,lock=True)
......@@ -11246,7 +11246,7 @@ class trace_probing(probing):
pool.join()
raise Exception(about._errors.cstring("ERROR: unknown. NOTE: pool terminated.")) ## traceback by looping
## evaluate
if(self.domain.datatype in [np.complex64,np.complex128]):
if(issubclass(self.domain.datatype,np.complexfloating)):
_sum = np.complex(_sum[0].value,_sum[1].value)
else:
_sum = _sum.value
......@@ -11478,7 +11478,7 @@ class diagonal_probing(probing):
## check codomain
if(target is None):
target = domain.get_codomain()
target = self.domain.get_codomain()
else:
self.domain.check_codomain(target) ## a bit pointless
self.target = target
......@@ -11646,7 +11646,7 @@ class diagonal_probing(probing):
## define random seed
seed = np.random.randint(10**8,high=None,size=self.nrun)
## define shared objects
if(self.domain.datatype in [np.complex64,np.complex128]):
if(issubclass(self.domain.datatype,np.complexfloating)):
_sum = (ma('d',np.zeros(self.domain.dim(split=False),dtype=np.float64,order='C'),lock=True),ma('d',np.zeros(self.domain.dim(split=False),dtype=np.float64,order='C'),lock=True)) ## tuple(real,imag)
else:
_sum = ma('d',np.zeros(self.domain.dim(split=False),dtype=np.float64,order='C'),lock=True)
......@@ -11673,7 +11673,7 @@ class diagonal_probing(probing):
pool.join()
raise Exception(about._errors.cstring("ERROR: unknown. NOTE: pool terminated.")) ## traceback by looping
## evaluate
if(self.domain.datatype in [np.complex64,np.complex128]):
if(issubclass(self.domain.datatype,np.complexfloating)):
_sum = (np.array(_sum[0][:])+np.array(_sum[1][:])*1j).reshape(self.domain.dim(split=True)) ## comlpex array
else:
_sum = np.array(_sum[:]).reshape(self.domain.dim(split=True))
......
......@@ -107,8 +107,8 @@ class explicit_operator(operator):
if(len(bare)!=2):
raise ValueError(about._errors.cstring("ERROR: invalid input."))
else:
val = self._calc_weight_rows(val,-bool(bare[0]))
val = self._calc_weight_cols(val,-bool(bare[1]))
val = self._calc_weight_rows(val,power=-int(not bare[0]))
val = self._calc_weight_cols(val,power=-int(not bare[1]))
elif(not bare):
val = self._calc_weight_rows(val,-1)
if(purelyreal):
......@@ -156,6 +156,38 @@ class explicit_operator(operator):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def cast_domain(self,newdomain):
"""
TODO: documentation
"""
if(not isinstance(newdomain,space)):
raise TypeError(about._errors.cstring("ERROR: invalid input."))
elif(newdomain.dim(split=False)!=self.domain.dim(split=False)):
raise ValueError(about._errors.cstring("ERROR: dimension mismatch ( "+str(newdomain.dim(split=False))+" <> "+str(self.domain.dim(split=False))+" )."))
self.domain = newdomain
def cast_target(self,newtarget):
"""
TODO: documentation
"""
if(not isinstance(newtarget,space)):
raise TypeError(about._errors.cstring("ERROR: invalid input."))
elif(newtarget.dim(split=False)!=self.target.dim(split=False)):
raise ValueError(about._errors.cstring("ERROR: dimension mismatch ( "+str(newtarget.dim(split=False))+" <> "+str(self.target.dim(split=False))+" )."))
self.target = newtarget
def cast_spaces(self,newdomain,newtarget):
"""
TODO: documentation
"""
self.cast_domain(newdomain)
self.cast_target(newtarget)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def set_matrix(self,newmatrix,bare=True,sym=None,uni=None):
"""
TODO: documentation
......@@ -182,8 +214,8 @@ class explicit_operator(operator):
if(len(bare)!=2):
raise ValueError(about._errors.cstring("ERROR: invalid input."))
else:
val = self._calc_weight_rows(val,-bool(bare[0]))
val = self._calc_weight_cols(val,-bool(bare[1]))
val = self._calc_weight_rows(val,power=-int(not bare[0]))
val = self._calc_weight_cols(val,power=-int(not bare[1]))
elif(not bare):
val = self._calc_weight_rows(val,-1)
if(purelyreal):
......@@ -209,9 +241,9 @@ class explicit_operator(operator):
if(len(bare)!=2):
raise ValueError(about._errors.cstring("ERROR: invalid input."))
else:
return self.weight(rowpower=bool(bare[0]),colpower=bool(bare[1]),overwrite=False)
return self.weight(rowpower=int(not bare[0]),colpower=int(not bare[1]),overwrite=False)
elif(not bare):
return self.weight(rowpower=bool(bare),colpower=0,overwrite=False)
return self.weight(rowpower=int(not bare),colpower=0,overwrite=False)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -222,10 +254,16 @@ class explicit_operator(operator):
return self.target.calc_weight(x,power=power).flatten(order='C')
def _calc_weight_rows(self,X,power=1): ## > weight all rows
return np.apply_along_axis(self._calc_weight_row,1,X,power)
if(np.any(np.iscomplex(X)))and(not issubclass(self.domain.datatype,np.complexfloating)):
return (np.apply_along_axis(self._calc_weight_row,1,np.real(X),power)+np.apply_along_axis(self._calc_weight_row,1,np.imag(X),power)*1j)
else:
return np.apply_along_axis(self._calc_weight_row,1,X,power)
def _calc_weight_cols(self,X,power=1): ## > weight all columns
return np.apply_along_axis(self._calc_weight_col,0,X,power)
if(np.any(np.iscomplex(X)))and(not issubclass(self.target.datatype,np.complexfloating)):
return (np.apply_along_axis(self._calc_weight_col,0,np.real(X),power)+np.apply_along_axis(self._calc_weight_col,0,np.imag(X),power)*1j)
else:
return np.apply_along_axis(self._calc_weight_col,0,X,power)
def weight(self,rowpower=0,colpower=0,overwrite=False):
"""
......@@ -1448,7 +1486,7 @@ class explicit_probing(probing):
# Keyword arguments passed to `function` in each call.
#
# """
def __init__(self,op=None,function=None,domain=None,codomain=None,target=None,ncpu=2,nper=None,**quargs):
def __init__(self,op=None,function=None,domain=None,codomain=None,target=None,ncpu=2,nper=1,**quargs):
"""
TODO: documentation
......@@ -1489,7 +1527,7 @@ class explicit_probing(probing):
# If on the other hand `nper=1`, then for each evaluation a worker will
# be created. In this case all cpus will work until nrun probes have
# been evaluated.
# It is recommended to leave `nper` as the default value. (default: 8)
# It is recommended to leave `nper` as the default value. (default: 1)
#
# """
if(op is None):
......@@ -1546,12 +1584,12 @@ class explicit_probing(probing):
self.domain = domain
self.codomain = codomain
## check target
if(target is None):
self.target = domain.get_codomain()
target = self.domain.get_codomain()
else:
## check codomain
self.domain.check_codomain(target) ## a bit pointless
self.target = target
self.target = target
## check shape
if(self.domain.dim(split=False)*self.codomain.dim(split=False)>1048576):
......@@ -1717,7 +1755,7 @@ class explicit_probing(probing):
## define weighted canonical base
base = self.domain.calc_weight(self.domain.enforce_values(1,extend=True),power=-1).flatten(order='C')
## define shared objects
if(self.codomain.datatype in [np.complex64,np.complex128]):
if(issubclass(self.codomain.datatype,np.complexfloating)):
_mat = (ma('d',np.empty(self.nrun*self.codomain.dim(split=False),dtype=np.float64,order='C'),lock=True),ma('d',np.empty(self.nrun*self.codomain.dim(split=False),dtype=np.float64,order='C'),lock=True)) ## tuple(real,imag)
else:
_mat = ma('d',np.empty(self.nrun*self.codomain.dim(split=False),dtype=np.float64,order='C'),lock=True)
......@@ -1740,7 +1778,7 @@ class explicit_probing(probing):
pool.join()
raise Exception(about._errors.cstring("ERROR: unknown. NOTE: pool terminated.")) ## traceback by looping
## evaluate
if(self.domain.datatype in [np.complex64,np.complex128]):
if(issubclass(self.codomain.datatype,np.complexfloating)):
_mat = (np.array(_mat[0][:])+np.array(_mat[1][:])*1j).reshape((self.nrun,self.codomain.dim(split=False))) ## comlpex array
else:
_mat = np.array(_mat[:]).reshape((self.nrun,self.codomain.dim(split=False)))
......@@ -1805,12 +1843,42 @@ class explicit_probing(probing):
##-----------------------------------------------------------------------------
def explicify(operator,loop=False,**kwargs):
def explicify(operator,newdomain=None,newtarget=None,ncpu=2,nper=1,loop=False,**quargs):
"""
TODO: documentation
"""
return explicit_probing(op=operator,**kwargs)(loop=loop)
return explicit_probing(op=operator,function=operator.times,domain=newdomain,codomain=newtarget,target=operator.domain,ncpu=ncpu,nper=nper,**quargs)(loop=loop)
##-----------------------------------------------------------------------------
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment