Commit 4a3fa523 authored by Ultima's avatar Ultima
Browse files

Finished rework of probing classes.

Finished operator and diagonal_operator.
parent 988981a0
...@@ -18,6 +18,7 @@ operators/* ...@@ -18,6 +18,7 @@ operators/*
!operators/nifty_explicit.py !operators/nifty_explicit.py
!operators/nifty_operators.py !operators/nifty_operators.py
!operators/nifty_probing.py !operators/nifty_probing.py
!operators/nifty_probing_old.py
rg/* rg/*
......
...@@ -912,17 +912,14 @@ class space(object): ...@@ -912,17 +912,14 @@ class space(object):
else: else:
return mars return mars
def __eq__(self,x): ## __eq__ : self == x def __eq__(self, x): ## __eq__ : self == x
if isinstance(x, type(self)): if isinstance(x, type(self)):
return self._identifier() == x._identifier() return self._identifier() == x._identifier()
else: else:
return False return False
def __ne__(self,x): ## __ne__ : self <> x def __ne__(self, x):
if(isinstance(x,space)): return not self.__eq__(x)
if(not isinstance(x,type(self)))or(np.any(self.para!=x.para))or(self.discrete!=x.discrete)or(np.any(self.vol!=x.vol))or(np.any(self._meta_vars()!=x._meta_vars())): ## data types are ignored
return True
return False
def __lt__(self,x): ## __lt__ : self < x def __lt__(self,x): ## __lt__ : self < x
if(isinstance(x,space)): if(isinstance(x,space)):
...@@ -1165,6 +1162,8 @@ class point_space(space): ...@@ -1165,6 +1162,8 @@ class point_space(space):
"argmax" : _argmax, "argmax" : _argmax,
"argmax_flat" : np.argmax, "argmax_flat" : np.argmax,
"conjugate" : np.conjugate, "conjugate" : np.conjugate,
"sum" : np.sum,
"prod" : np.prod,
"None" : lambda y: y} "None" : lambda y: y}
...@@ -2648,7 +2647,7 @@ class field(object): ...@@ -2648,7 +2647,7 @@ class field(object):
The space wherein the operator output lives (default: domain). The space wherein the operator output lives (default: domain).
""" """
def __init__(self,domain,val=None,target=None,**kwargs): def __init__(self, domain, val=None, target=None, **kwargs):
""" """
Sets the attributes for a field class instance. Sets the attributes for a field class instance.
...@@ -2878,6 +2877,7 @@ class field(object): ...@@ -2878,6 +2877,7 @@ class field(object):
new_field.set_val(new_val = self.domain.calc_weight(self.get_val(), new_field.set_val(new_val = self.domain.calc_weight(self.get_val(),
power = power)) power = power))
return new_field
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......
...@@ -54,7 +54,7 @@ from nifty_operators import operator,\ ...@@ -54,7 +54,7 @@ from nifty_operators import operator,\
diagonal_operator,\ diagonal_operator,\
identity,\ identity,\
vecvec_operator vecvec_operator
from nifty_probing import probing from nifty_probing_old import probing
##----------------------------------------------------------------------------- ##-----------------------------------------------------------------------------
......
...@@ -27,9 +27,13 @@ from nifty.nifty_core import space, \ ...@@ -27,9 +27,13 @@ from nifty.nifty_core import space, \
nested_space, \ nested_space, \
field field
from nifty_minimization import conjugate_gradient from nifty_minimization import conjugate_gradient
from nifty_probing import trace_probing, \ from nifty_probing import trace_prober,\
diagonal_probing inverse_trace_prober,\
from nifty_mpi_probing import prober diagonal_prober,\
inverse_diagonal_prober
import nifty_simple_math
##============================================================================= ##=============================================================================
...@@ -104,7 +108,7 @@ class operator(object): ...@@ -104,7 +108,7 @@ class operator(object):
operator class can use. Not used in the base operators. operator class can use. Not used in the base operators.
""" """
def __init__(self, domain, sym=False, uni=False, imp=False, target=None,\ def __init__(self, domain, sym=False, uni=False, imp=False, target=None,\
para=None): bare = False, para=None):
""" """
Sets the attributes for an operator class instance. Sets the attributes for an operator class instance.
...@@ -141,6 +145,7 @@ class operator(object): ...@@ -141,6 +145,7 @@ class operator(object):
## Cast the symmetric and unitary input ## Cast the symmetric and unitary input
self.sym = bool(sym) self.sym = bool(sym)
self.uni = bool(uni) self.uni = bool(uni)
self.bare = bool(bare)
## If no target is supplied, we assume that the operator is square ## If no target is supplied, we assume that the operator is square
## If the operator is symmetric or unitary, we know that the operator ## If the operator is symmetric or unitary, we know that the operator
...@@ -477,8 +482,8 @@ class operator(object): ...@@ -477,8 +482,8 @@ class operator(object):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def tr(self, domain=None, target=None, random="pm1", ncpu=2, nrun=8,\ def tr(self, domain=None, codomain=None, random="pm1", nrun=8,
nper=1, var=False, loop=False, **kwargs): varQ=False, **kwargs):
""" """
Computes the trace of the operator Computes the trace of the operator
...@@ -520,21 +525,19 @@ class operator(object): ...@@ -520,21 +525,19 @@ class operator(object):
-------- --------
probing : The class used to perform probing operations probing : The class used to perform probing operations
""" """
if domain is None:
domain = self.domain return trace_prober(operator = self,
return trace_probing(self, domain = domain,
function=self.times, codomain = codomain,
domain=domain, random = random,
target=target, nrun = nrun,
random=random, varQ = varQ,
ncpu=(ncpu,1)[bool(loop)], **kwargs
nrun=nrun, )()
nper=nper,
var=var,
**kwargs)(loop=loop) def inverse_tr(self, domain=None, codomain=None, random="pm1", nrun=8,
varQ=False, **kwargs):
def inverse_tr(self, domain=None, target=None, random="pm1", ncpu=2,
nrun=8, nper=1, var=False, loop=False, **kwargs):
""" """
Computes the trace of the inverse operator Computes the trace of the inverse operator
...@@ -551,18 +554,12 @@ class operator(object): ...@@ -551,18 +554,12 @@ class operator(object):
for a random vector with entries drawn from a Gaussian for a random vector with entries drawn from a Gaussian
distribution with zero mean and unit variance. distribution with zero mean and unit variance.
(default: "pm1") (default: "pm1")
ncpu : int, *optional*
number of used CPUs to use. (default: 2)
nrun : int, *optional* nrun : int, *optional*
total number of probes (default: 8) total number of probes (default: 8)
nper : int, *optional* varQ : bool, *optional*
number of tasks performed by one process (default: 1)
var : bool, *optional*
Indicates whether to additionally return the probing variance Indicates whether to additionally return the probing variance
or not (default: False). or not (default: False).
loop : bool, *optional*
Indicates whether or not to perform a loop i.e., to
parallelise (default: False)
Returns Returns
------- -------
...@@ -576,23 +573,19 @@ class operator(object): ...@@ -576,23 +573,19 @@ class operator(object):
-------- --------
probing : The class used to perform probing operations probing : The class used to perform probing operations
""" """
if(domain is None): return inverse_trace_prober(operator = self,
domain = self.target domain = domain,
return trace_probing(self, codomain = codomain,
function=self.inverse_times, random = random,
domain=domain, nrun = nrun,
target=target, varQ = varQ,
random=random, **kwargs
ncpu=(ncpu,1)[bool(loop)], )()
nrun=nrun,
nper=nper,
var=var, **kwargs)(loop=loop)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def diag(self, bare=False, domain=None, target=None, random="pm1", ncpu=2, def diag(self, domain=None, codomain=None, random="pm1", nrun=8,
nrun=8, nper=1, var=False, save=False, path="tmp", prefix="", varQ=False, bare=False, **kwargs):
loop=False, **kwargs):
""" """
Computes the diagonal of the operator via probing. Computes the diagonal of the operator via probing.
...@@ -657,37 +650,33 @@ class operator(object): ...@@ -657,37 +650,33 @@ class operator(object):
entries; e.g., as variance in case of an covariance operator. entries; e.g., as variance in case of an covariance operator.
""" """
if(domain is None):
domain = self.domain diag = diagonal_prober(operator = self,
diag = diagonal_probing(self, domain = domain,
function=self.times, codomain = codomain,
domain=domain, random = random,
target=target, nrun = nrun,
random=random, varQ = varQ,
ncpu=(ncpu,1)[bool(loop)], **kwargs
nrun=nrun, )()
nper=nper,
var=var,
save=save,
path=path,
prefix=prefix,
**kwargs)(loop=loop)
if diag is None: if diag is None:
# about.warnings.cprint("WARNING: forwarding 'NoneType'.") about.warnings.cprint("WARNING: forwarding 'NoneType'.")
return None return None
if domain is None:
domain = self.domain
## weight if ... ## weight if ...
elif domain.discrete == False and bare == True: if domain.discrete == False and bare == True:
if isinstance(diag, tuple): ## diag == (diag,variance) if(isinstance(diag,tuple)): ## diag == (diag,variance)
return (domain.calc_weight(diag[0],power=-1), return (diag[0].weight(power=-1),
domain.calc_weight(diag[1],power=-1)) diag[1].weight(power=-1))
else: else:
return domain.calc_weight(diag,power=-1) return diag.weight(power=-1)
else: else:
return diag return diag
def inverse_diag(self, bare=False, domain=None, target=None, random="pm1", def inverse_diag(self, domain=None, codomain=None, random="pm1", nrun=8,
ncpu=2, nrun=8, nper=1, var=False, save=False, path="tmp", varQ=False, bare=False, **kwargs):
prefix="", loop=False, **kwargs):
""" """
Computes the diagonal of the inverse operator via probing. Computes the diagonal of the inverse operator via probing.
...@@ -754,32 +743,31 @@ class operator(object): ...@@ -754,32 +743,31 @@ class operator(object):
""" """
if(domain is None): if(domain is None):
domain = self.target domain = self.target
diag = diagonal_probing(self, diag = inverse_diagonal_prober(operator = self,
function=self.inverse_times, domain = domain,
domain=domain, codomain = codomain,
target=target, random = random,
random=random, nrun = nrun,
ncpu=(ncpu,1)[bool(loop)], varQ = varQ,
nrun=nrun, **kwargs
nper=nper, )()
var=var,
save=save,
path=path,
prefix=prefix,
**kwargs)(loop=loop)
if(diag is None): if(diag is None):
# about.warnings.cprint("WARNING: forwarding 'NoneType'.") # about.warnings.cprint("WARNING: forwarding 'NoneType'.")
return None return None
if domain is None:
domain = self.target
## weight if ... ## weight if ...
elif(not domain.discrete)and(bare): if domain.discrete == False and bare == True:
if(isinstance(diag,tuple)): ## diag == (diag,variance) if(isinstance(diag,tuple)): ## diag == (diag,variance)
return (domain.calc_weight(diag[0],power=-1), return (diag[0].weight(power=-1),
domain.calc_weight(diag[1],power=-1)) diag[1].weight(power=-1))
else: else:
return domain.calc_weight(diag,power=-1) return diag.weight(power=-1)
else: else:
return diag return diag
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def det(self): def det(self):
...@@ -1217,6 +1205,7 @@ class diagonal_operator(operator): ...@@ -1217,6 +1205,7 @@ class diagonal_operator(operator):
else: else:
self.domain = domain self.domain = domain
self.target = self.domain self.target = self.domain
self.imp = True self.imp = True
self.set_diag(new_diag = diag) self.set_diag(new_diag = diag)
...@@ -1416,7 +1405,7 @@ class diagonal_operator(operator): ...@@ -1416,7 +1405,7 @@ class diagonal_operator(operator):
""" """
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def tr(self,domain=None,**kwargs): def tr(self, varQ=False, **kwargs):
""" """
Computes the trace of the operator Computes the trace of the operator
...@@ -1454,6 +1443,15 @@ class diagonal_operator(operator): ...@@ -1454,6 +1443,15 @@ class diagonal_operator(operator):
Probing variance of the trace. Returned if `var` is True in Probing variance of the trace. Returned if `var` is True in
of probing case. of probing case.
"""
tr = self.domain.unary_operation(self.val, 'sum')
if varQ == True:
return (tr, 1)
else:
return tr
""" """
if(domain is None)or(domain==self.domain): if(domain is None)or(domain==self.domain):
if(self.uni): ## identity if(self.uni): ## identity
...@@ -1473,7 +1471,8 @@ class diagonal_operator(operator): ...@@ -1473,7 +1471,8 @@ class diagonal_operator(operator):
else: else:
return super(diagonal_operator,self).tr(domain=domain,**kwargs) ## probing return super(diagonal_operator,self).tr(domain=domain,**kwargs) ## probing
def inverse_tr(self,domain=None,**kwargs): """
def inverse_tr(self, varQ=False, **kwargs):
""" """
Computes the trace of the inverse operator Computes the trace of the inverse operator
...@@ -1512,6 +1511,19 @@ class diagonal_operator(operator): ...@@ -1512,6 +1511,19 @@ class diagonal_operator(operator):
of probing case. of probing case.
""" """
if (self.get_val() == 0).any():
raise AttributeError(about._errors.cstring(
"ERROR: singular operator."))
inverse_tr = self.domain.unary_operation(
self.domain.binary_operation(self.val, 1, 'rdiv', cast=0),
'sum')
if varQ == True:
return (inverse_tr, 1)
else:
return inverse_tr
"""
if(np.any(self.val==0)): if(np.any(self.val==0)):
raise AttributeError(about._errors.cstring("ERROR: singular operator.")) raise AttributeError(about._errors.cstring("ERROR: singular operator."))
...@@ -1521,7 +1533,7 @@ class diagonal_operator(operator): ...@@ -1521,7 +1533,7 @@ class diagonal_operator(operator):
elif(self.domain.dim(split=False)<self.domain.dof()): ## hidden degrees of freedom elif(self.domain.dim(split=False)<self.domain.dof()): ## hidden degrees of freedom
return self.domain.calc_dot(np.ones(self.domain.dim(split=True),dtype=self.domain.datatype,order='C'),1/self.val) ## discrete inner product return self.domain.calc_dot(np.ones(self.domain.dim(split=True),dtype=self.domain.datatype,order='C'),1/self.val) ## discrete inner product
else: else:
return np.sum(1/self.val,axis=None,dtype=None,out=None) return np.sum(1./self.val,axis=None,dtype=None,out=None)
else: else:
if(self.uni): ## identity if(self.uni): ## identity
if(not isinstance(domain,space)): if(not isinstance(domain,space)):
...@@ -1532,10 +1544,11 @@ class diagonal_operator(operator): ...@@ -1532,10 +1544,11 @@ class diagonal_operator(operator):
return np.real(domain.datatype(domain.dof())) return np.real(domain.datatype(domain.dof()))
else: else:
return super(diagonal_operator,self).inverse_tr(domain=domain,**kwargs) ## probing return super(diagonal_operator,self).inverse_tr(domain=domain,**kwargs) ## probing
"""
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def diag(self,bare=False,domain=None,**kwargs): def diag(self, bare=False, domain=None, varQ=False, **kwargs):
""" """
Computes the diagonal of the operator. Computes the diagonal of the operator.
...@@ -1596,6 +1609,21 @@ class diagonal_operator(operator): ...@@ -1596,6 +1609,21 @@ class diagonal_operator(operator):
entries; e.g., as variance in case of an covariance operator. entries; e.g., as variance in case of an covariance operator.
""" """
diag = super(diagonal_operator, self).diag(bare=bare,
domain=domain,
nrun=1,
random='pm1',
varQ=False,
**kwargs)
if varQ == True:
return (diag, diag.domain.cast(1))
else:
return diag
"""
if(domain is None)or(domain==self.domain): if(domain is None)or(domain==self.domain):
## weight if ... ## weight if ...
if(not self.domain.discrete)and(bare): if(not self.domain.discrete)and(bare):
...@@ -1618,7 +1646,8 @@ class diagonal_operator(operator): ...@@ -1618,7 +1646,8 @@ class diagonal_operator(operator):
else: else:
return super(diagonal_operator,self).diag(bare=bare,domain=domain,**kwargs) ## probing return super(diagonal_operator,self).diag(bare=bare,domain=domain,**kwargs) ## probing
def inverse_diag(self,bare=False,domain=None,**kwargs): """
def inverse_diag(self,bare=False,domain=None, varQ=False, **kwargs):
""" """
Computes the diagonal of the inverse operator. Computes the diagonal of the inverse operator.
...@@ -1683,6 +1712,19 @@ class diagonal_operator(operator): ...@@ -1683,6 +1712,19 @@ class diagonal_operator(operator):
entries; e.g., as variance in case of an covariance operator. entries; e.g., as variance in case of an covariance operator.
""" """
inverse_diag = super(diagonal_operator, self).inverse_diag(bare=bare,
domain=domain,
nrun=1,
random='pm1',
varQ=False,
**kwargs)
if varQ == True:
return (inverse_diag, inverse_diag.domain.cast(1))
else:
return inverse_diag
"""
if(domain is None)or(domain==self.target): if(domain is None)or(domain==self.target):
## weight if ... ## weight if ...
if(not self.domain.discrete)and(bare): if(not self.domain.discrete)and(bare):
...@@ -1704,7 +1746,8 @@ class diagonal_operator(operator): ...@@ -1704,7 +1746,8 @@ class diagonal_operator(operator):
return np.real(domain.enforce_values(1,extend=True)) return np.real(domain.enforce_values(1,extend=True))
else: else:
return super(diagonal_operator,self).inverse_diag(bare=bare,domain=domain,**kwargs) ## probing return super(diagonal_operator,self).inverse_diag(bare=bare,domain=domain,**kwargs) ## probing
"""
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def det(self): def det(self):
...@@ -1717,12 +1760,14 @@ class diagonal_operator(operator): ...@@ -1717,12 +1760,14 @@ class diagonal_operator(operator):
The determinant The determinant
""" """
if(self.uni): ## identity if self.uni == True: ## identity
return 1 return 1.
elif(self.domain.dim(split=False)<self.domain.dof()): ## hidden degrees of freedom
return np.exp(self.domain.calc_dot(np.ones(self.domain.dim(split=True),dtype=self.domain.datatype,order='C'),np.log(self.val)))
else: else:
return np.prod(self.val,axis=None,dtype=None,out=None) return self.domain.unary_operation(self.val, op='prod')
#elif(self.domain.dim(split=False)<self.domain.dof()): ## hidden degrees of freedom
# return np.exp(self.domain.calc_dot(np.ones(self.domain.dim(split=True),dtype=self.domain.datatype,order='C'),np.log(self.val)))
#else:
# return np.prod(self.val,axis=None,dtype=None,out=None)
def inverse_det(self): def inverse_det(self):
""" """
...@@ -1734,11 +1779,13 @@ class diagonal_operator(operator): ...@@ -1734,11 +1779,13 @@ class diagonal_operator(operator):
The determinant The determinant
""" """
if(self.uni): ## identity if self.uni == True: ## identity
return 1 return 1.
det = self.det() det = self.det()