Commit 148e6fbc authored by Ultima's avatar Ultima
Browse files

Started rework of operators. operator finished. diagonal_operator started.

parent df30d59c
......@@ -1728,6 +1728,11 @@ class point_space(space):
"""
raise AttributeError(about._errors.cstring("ERROR: power spectra ill-defined for (unstructured) point space."))
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def calc_real_Q(self, x):
return np.all(np.isreal(x))
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def get_plot(self,x,title="",vmin=None,vmax=None,unit="",norm=None,other=None,legend=False,**kwargs):
......
......@@ -23,8 +23,9 @@
##initialize the 'found-packages'-dictionary
found = {}
found = {}
import numpy as np
from nifty_about import about
......@@ -49,8 +50,6 @@ except(ImportError):
found['h5py'] = False
found['h5py_parallel'] = False
class distributed_data_object(object):
"""
......@@ -217,7 +216,7 @@ class distributed_data_object(object):
return False
## Case 4: 'other' is something different
## -> make a numpy casting and make a recursion
## -> make a numpy casting and make a recursive call
else:
temp_other = np.array(other)
return self.__eq__(temp_other)
......@@ -543,10 +542,22 @@ class distributed_data_object(object):
return temp_d2o
def is_completely_real(self):
local_realiness = np.all(self.isreal())
local_realiness = np.all(self.isreal().get_local_data())
global_realiness = self.distributor._allgather(local_realiness)
return np.all(global_realiness)
def all(self):
local_all = np.all(self.get_local_data())
global_all = self.distributor._allgather(local_all)
return np.all(global_all)
def any(self):
local_any = np.any(self.get_local_data())
global_any = self.distributor._allgather(local_any)
return np.all(global_any)
def set_local_data(self, data, hermitian=False, copy=False):
"""
Stores data directly in the local data attribute. No distribution
......@@ -1464,10 +1475,12 @@ class _not_distributor(object):
return np.array(data).astype(self.dtype, copy=False).\
reshape(self.global_shape)
def disperse_data(self, data, data_update, key, **kwargs):
data[key] = np.array(data_update, copy=False).astype(self.dtype)
def disperse_data(self, data, to_slices, data_update, from_slices=None,
**kwargs):
data[to_slices] = np.array(data_update[from_slices],
copy=False).astype(self.dtype)
def collect_data(self, data, slice_objects, **kwargs):
def collect_data(self, data, slice_objects, **kwargs):
return data[slice_objects]
def consolidate_data(self, data, **kwargs):
......
......@@ -102,7 +102,8 @@ class operator(object):
This is a freeform list of parameters that derivatives of the
operator class can use. Not used in the base operators.
"""
def __init__(self,domain,sym=False,uni=False,imp=False,target=None,para=None):
def __init__(self, domain, sym=False, uni=False, imp=False, target=None,\
para=None):
"""
Sets the attributes for an operator class instance.
......@@ -131,25 +132,38 @@ class operator(object):
-------
None
"""
if(not isinstance(domain,space)):
raise TypeError(about._errors.cstring("ERROR: invalid input."))
## Check if the domain is realy a space
if not isinstance(domain,space):
raise TypeError(about._errors.cstring(
"ERROR: invalid input. domain is not a space."))
self.domain = domain
## Cast the symmetric and unitary input
self.sym = bool(sym)
self.uni = bool(uni)
if(target is None)or(self.sym)or(self.uni):
## If no target is supplied, we assume that the operator is square
## If the operator is symmetric or unitary, we know that the operator
## must be square
if self.sym == True or self.uni == True:
target = self.domain
if target is not None:
about.warnings.cprint("WARNING: Ignoring target.")
elif target is None:
target = self.domain
elif(not isinstance(target,space)):
raise TypeError(about._errors.cstring("ERROR: invalid input."))
elif not isinstance(target, space):
raise TypeError(about._errors.cstring(
"ERROR: invalid input. Target is not a space."))
self.target = target
if(self.domain.discrete)and(self.target.discrete):
if self.domain.discrete and self.target.discrete:
self.imp = True
else:
self.imp = bool(imp)
if(para is not None):
self.para = para
self.para = para
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -175,7 +189,7 @@ class operator(object):
"""
return self.domain.dim(split=False)
def dim(self,axis=None):
def dim(self, axis=None):
"""
Computes the dimension of the space
......@@ -191,14 +205,15 @@ class operator(object):
The dimension(s) of the operator.
"""
if(axis is None):
if axis is None:
return np.array([self.nrow(),self.ncol()])
elif(axis==0):
elif axis == 0:
return self.nrow()
elif(axis==1):
elif axis == 1:
return self.ncol()
else:
raise ValueError(about._errors.cstring("ERROR: invalid input axis."))
raise ValueError(about._errors.cstring(
"ERROR: invalid input axis."))
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -221,63 +236,77 @@ class operator(object):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def _multiply(self,x,**kwargs): ## > applies the operator to a given field
raise NotImplementedError(about._errors.cstring("ERROR: no generic instance method 'multiply'."))
def _multiply(self, x, **kwargs):
## > applies the operator to a given field
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'multiply'."))
def _adjoint_multiply(self,x,**kwargs): ## > applies the adjoint operator to a given field
raise NotImplementedError(about._errors.cstring("ERROR: no generic instance method 'adjoint_multiply'."))
def _adjoint_multiply(self, x, **kwargs):
## > applies the adjoint operator to a given field
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'adjoint_multiply'."))
def _inverse_multiply(self,x,**kwargs): ## > applies the inverse operator to a given field
raise NotImplementedError(about._errors.cstring("ERROR: no generic instance method 'inverse_multiply'."))
def _inverse_multiply(self, x, **kwargs):
## > applies the inverse operator to a given field
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'inverse_multiply'."))
def _adjoint_inverse_multiply(self,x,**kwargs): ## > applies the inverse adjoint operator to a given field
raise NotImplementedError(about._errors.cstring("ERROR: no generic instance method 'adjoint_inverse_multiply'."))
def _adjoint_inverse_multiply(self, x, **kwargs):
## > applies the inverse adjoint operator to a given field
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'adjoint_inverse_multiply'."))
def _inverse_adjoint_multiply(self,x,**kwargs): ## > applies the adjoint inverse operator to a given field
raise NotImplementedError(about._errors.cstring("ERROR: no generic instance method 'inverse_adjoint_multiply'."))
def _inverse_adjoint_multiply(self, x, **kwargs):
## > applies the adjoint inverse operator to a given field
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'inverse_adjoint_multiply'."))
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def _briefing(self,x,domain,inverse): ## > prepares x for `multiply`
def _briefing(self, x, domain, inverse): ## > prepares x for `multiply`
## inspect x
if(not isinstance(x,field)):
x_ = field(domain,val=x,target=None)
if not isinstance(x, field):
x_ = field(domain, val=x)
else:
## check x.domain
if(domain==x.domain):
if domain == x.domain:
x_ = x
## transform
else:
x_ = x.transform(target=domain,overwrite=False)
x_ = x.transform(target=domain)
## weight if ...
if(not self.imp)and(not domain.discrete)and(not inverse):
x_ = x_.weight(power=1,overwrite=False)
if self.imp == False and domain.discrete == False and inverse == False:
x_ = x_.weight(power=1)
return x_
def _debriefing(self,x,x_,target,inverse): ## > evaluates x and x_ after `multiply`
if(x_ is None):
def _debriefing(self, x, x_, target, inverse):
## > evaluates x and x_ after `multiply`
if x_ is None:
return None
else:
## inspect x_
if(not isinstance(x_,field)):
x_ = field(target,val=x_,target=None)
elif(x_.domain!=target):
raise ValueError(about._errors.cstring("ERROR: invalid output domain."))
if not isinstance(x_, field):
x_ = field(target, val=x_)
elif x_.domain != target:
raise ValueError(about._errors.cstring(
"ERROR: invalid output domain."))
## weight if ...
if(not self.imp)and(not target.discrete)and(inverse):
x_ = x_.weight(power=-1,overwrite=False)
if self.imp == False and target.discrete == False and\
inverse == True:
x_ = x_.weight(power=-1)
## inspect x
if(isinstance(x,field)):
## repair ...
if(self.domain==self.target!=x.domain):
x_ = x_.transform(target=x.domain,overwrite=False) ## ... domain
if(x_.domain==x.domain)and(x_.target!=x.target):
x_.set_target(newtarget=x.target) ## ... codomain
if isinstance(x, field):
## repair if the originally field was living in the codomain
## of the operators domain
if self.domain == self.target != x.domain:
x_ = x_.transform(target=x.domain)
if x_.domain == x.domain and (x_.target is not x.target):
x_.set_target(newtarget = x.target)
return x_
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def times(self,x,**kwargs):
def times(self, x, **kwargs):
"""
Applies the operator to a given object
......@@ -293,16 +322,16 @@ class operator(object):
Mapped field on the target domain of the operator.
"""
## prepare
x_ = self._briefing(x,self.domain,False)
x_ = self._briefing(x, self.domain, inverse=False)
## apply operator
x_ = self._multiply(x_,**kwargs)
x_ = self._multiply(x_, **kwargs)
## evaluate
return self._debriefing(x,x_,self.target,False)
return self._debriefing(x, x_, self.target, inverse=False)
def __call__(self,x,**kwargs):
return self.times(x,**kwargs)
def __call__(self, x, **kwargs):
return self.times(x, **kwargs)
def adjoint_times(self,x,**kwargs):
def adjoint_times(self, x, **kwargs):
"""
Applies the adjoint operator to a given object.
......@@ -319,20 +348,20 @@ class operator(object):
"""
## check whether self-adjoint
if(self.sym):
return self.times(x,**kwargs)
if self.sym == True:
return self.times(x, **kwargs)
## check whether unitary
if(self.uni):
return self.inverse_times(x,**kwargs)
if self.uni == True:
return self.inverse_times(x, **kwargs)
## prepare
x_ = self._briefing(x,self.target,False)
x_ = self._briefing(x, self.target, inverse=False)
## apply operator
x_ = self._adjoint_multiply(x_,**kwargs)
x_ = self._adjoint_multiply(x_, **kwargs)
## evaluate
return self._debriefing(x,x_,self.domain,False)
return self._debriefing(x, x_, self.domain, inverse=False)
def inverse_times(self,x,**kwargs):
def inverse_times(self, x, **kwargs):
"""
Applies the inverse operator to a given object.
......@@ -348,17 +377,17 @@ class operator(object):
Mapped field on the target space of the operator.
"""
## check whether self-inverse
if(self.sym)and(self.uni):
if self.sym == True and self.uni == True:
return self.times(x,**kwargs)
## prepare
x_ = self._briefing(x,self.target,True)
x_ = self._briefing(x, self.target, inverse=True)
## apply operator
x_ = self._inverse_multiply(x_,**kwargs)
x_ = self._inverse_multiply(x_, **kwargs)
## evaluate
return self._debriefing(x,x_,self.domain,True)
return self._debriefing(x, x_, self.domain, inverse=True)
def adjoint_inverse_times(self,x,**kwargs):
def adjoint_inverse_times(self, x, **kwargs):
"""
Applies the inverse adjoint operator to a given object.
......@@ -375,20 +404,20 @@ class operator(object):
"""
## check whether self-adjoint
if(self.sym):
return self.inverse_times(x,**kwargs)
if self.sym == True:
return self.inverse_times(x, **kwargs)
## check whether unitary
if(self.uni):
return self.times(x,**kwargs)
if self.uni == True:
return self.times(x, **kwargs)
## prepare
x_ = self._briefing(x,self.domain,True)
x_ = self._briefing(x, self.domain, inverse=True)
## apply operator
x_ = self._adjoint_inverse_multiply(x_,**kwargs)
x_ = self._adjoint_inverse_multiply(x_, **kwargs)
## evaluate
return self._debriefing(x,x_,self.target,True)
return self._debriefing(x, x_, self.target, inverse=True)
def inverse_adjoint_times(self,x,**kwargs):
def inverse_adjoint_times(self, x, **kwargs):
"""
Applies the adjoint inverse operator to a given object.
......@@ -405,22 +434,23 @@ class operator(object):
"""
## check whether self-adjoint
if(self.sym):
return self.inverse_times(x,**kwargs)
if self.sym == True:
return self.inverse_times(x, **kwargs)
## check whether unitary
if(self.uni):
return self.times(x,**kwargs)
if self.uni == True:
return self.times(x, **kwargs)
## prepare
x_ = self._briefing(x,self.domain,True)
x_ = self._briefing(x, self.domain, inverse=True)
## apply operator
x_ = self._inverse_adjoint_multiply(x_,**kwargs)
x_ = self._inverse_adjoint_multiply(x_, **kwargs)
## evaluate
return self._debriefing(x,x_,self.target,True)
return self._debriefing(x, x_, self.target, inverse=True)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def tr(self,domain=None,target=None,random="pm1",ncpu=2,nrun=8,nper=1,var=False,loop=False,**kwargs):
def tr(self, domain=None, target=None, random="pm1", ncpu=2, nrun=8,\
nper=1, var=False, loop=False, **kwargs):
"""
Computes the trace of the operator
......@@ -462,11 +492,21 @@ class operator(object):
--------
probing : The class used to perform probing operations
"""
if(domain is None):
if domain is None:
domain = self.domain
return trace_probing(self,function=self.times,domain=domain,target=target,random=random,ncpu=(ncpu,1)[bool(loop)],nrun=nrun,nper=nper,var=var,**kwargs)(loop=loop)
def inverse_tr(self,domain=None,target=None,random="pm1",ncpu=2,nrun=8,nper=1,var=False,loop=False,**kwargs):
return trace_probing(self,
function=self.times,
domain=domain,
target=target,
random=random,
ncpu=(ncpu,1)[bool(loop)],
nrun=nrun,
nper=nper,
var=var,
**kwargs)(loop=loop)
def inverse_tr(self, domain=None, target=None, random="pm1", ncpu=2,
nrun=8, nper=1, var=False, loop=False, **kwargs):
"""
Computes the trace of the inverse operator
......@@ -510,11 +550,21 @@ class operator(object):
"""
if(domain is None):
domain = self.target
return trace_probing(self,function=self.inverse_times,domain=domain,target=target,random=random,ncpu=(ncpu,1)[bool(loop)],nrun=nrun,nper=nper,var=var,**kwargs)(loop=loop)
return trace_probing(self,
function=self.inverse_times,
domain=domain,
target=target,
random=random,
ncpu=(ncpu,1)[bool(loop)],
nrun=nrun,
nper=nper,
var=var, **kwargs)(loop=loop)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def diag(self,bare=False,domain=None,target=None,random="pm1",ncpu=2,nrun=8,nper=1,var=False,save=False,path="tmp",prefix="",loop=False,**kwargs):
def diag(self, bare=False, domain=None, target=None, random="pm1", ncpu=2,
nrun=8, nper=1, var=False, save=False, path="tmp", prefix="",
loop=False, **kwargs):
"""
Computes the diagonal of the operator via probing.
......@@ -581,20 +631,35 @@ class operator(object):
"""
if(domain is None):
domain = self.domain
diag = diagonal_probing(self,function=self.times,domain=domain,target=target,random=random,ncpu=(ncpu,1)[bool(loop)],nrun=nrun,nper=nper,var=var,save=save,path=path,prefix=prefix,**kwargs)(loop=loop)
if(diag is None):
diag = diagonal_probing(self,
function=self.times,
domain=domain,
target=target,
random=random,
ncpu=(ncpu,1)[bool(loop)],
nrun=nrun,
nper=nper,
var=var,
save=save,
path=path,
prefix=prefix,
**kwargs)(loop=loop)
if diag is None:
# about.warnings.cprint("WARNING: forwarding 'NoneType'.")
return None
## weight if ...
elif(not domain.discrete)and(bare):
if(isinstance(diag,tuple)): ## diag == (diag,variance)
return domain.calc_weight(diag[0],power=-1),domain.calc_weight(diag[1],power=-1)
elif domain.discrete == False and bare == True:
if isinstance(diag, tuple): ## diag == (diag,variance)
return (domain.calc_weight(diag[0],power=-1),
domain.calc_weight(diag[1],power=-1))
else:
return domain.calc_weight(diag,power=-1)
else:
return diag
def inverse_diag(self,bare=False,domain=None,target=None,random="pm1",ncpu=2,nrun=8,nper=1,var=False,save=False,path="tmp",prefix="",loop=False,**kwargs):
def inverse_diag(self, bare=False, domain=None, target=None, random="pm1",
ncpu=2, nrun=8, nper=1, var=False, save=False, path="tmp",
prefix="", loop=False, **kwargs):
"""
Computes the diagonal of the inverse operator via probing.
......@@ -661,14 +726,27 @@ class operator(object):
"""
if(domain is None):
domain = self.target
diag = diagonal_probing(self,function=self.inverse_times,domain=domain,target=target,random=random,ncpu=(ncpu,1)[bool(loop)],nrun=nrun,nper=nper,var=var,save=save,path=path,prefix=prefix,**kwargs)(loop=loop)
diag = diagonal_probing(self,
function=self.inverse_times,
domain=domain,
target=target,
random=random,
ncpu=(ncpu,1)[bool(loop)],
nrun=nrun,
nper=nper,
var=var,
save=save,
path=path,
prefix=prefix,
**kwargs)(loop=loop)
if(diag is None):
# about.warnings.cprint("WARNING: forwarding 'NoneType'.")
return None
## weight if ...
elif(not domain.discrete)and(bare):
if(isinstance(diag,tuple)): ## diag == (diag,variance)
return domain.calc_weight(diag[0],power=-1),domain.calc_weight(diag[1],power=-1)
return (domain.calc_weight(diag[0],power=-1),
domain.calc_weight(diag[1],power=-1))
else:
return domain.calc_weight(diag,power=-1)
else:
......@@ -686,7 +764,8 @@ class operator(object):
The determinant
"""
raise NotImplementedError(about._errors.cstring("ERROR: no generic instance method 'det'."))
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'det'."))
def inverse_det(self):
"""
......@@ -698,7 +777,8 @@ class operator(object):
The determinant
"""
raise NotImplementedError(about._errors.cstring("ERROR: no generic instance method 'inverse_det'."))
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'inverse_det'."))
def log_det(self):
"""
......@@ -710,7 +790,8 @@ class operator(object):
The logarithm of the determinant
"""
raise NotImplementedError(about._errors.cstring("ERROR: no generic instance method 'log_det'."))
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'log_det'."))
def tr_log(self):
"""
......@@ -786,13 +867,14 @@ class operator(object):
entries; e.g., as variance in case of an covariance operator.
"""
if(domain is None):
if domain is None:
domain = self.domain
diag = self.diag(bare=bare,domain=domain,target=target,var=False,**kwargs)
if(diag is None):
diag = self.diag(bare=bare, domain=domain, target=target,
var=False, **kwargs)
if diag is None:
about.warnings.cprint("WARNING: forwarding 'NoneType'.")
return None
return field(domain,val=diag,target=target)
return field(domain, val=diag, target=target)
def inverse_hat(self,bare=False,domain=None,target=None,**kwargs):
"""
......@@ -854,13 +936,14 @@ class operator(object):
entries; e.g., as variance in case of an covariance operator.
"""
if(domain is None):
if domain is None:
domain = self.target
diag = self.inverse_diag(bare=bare,domain=domain,target=target,var=False,**kwargs)
if(diag is None):
diag = self.inverse_diag(bare=bare, domain=domain, target=target,
var=False, **kwargs)
if diag is None:
about.warnings.cprint("WARNING: forwarding 'NoneType'.")
return None
return field(domain,val=diag,target=target)
return field(domain, val=diag, target=target)
def hathat(self,domain=None,**kwargs):
"""
......@@ -917,13 +1000,13 @@ class operator(object):
entries; e.g., as variance in case of an covariance operator.