Commit 988981a0 authored by Ultima's avatar Ultima
Browse files

Some bugfixes.

Added direct_dot to nifty_simple_math.py.
Begun with direct probing implementation.
parent 148e6fbc
Metadata-Version: 1.0
Name: ift_nifty
Version: 1.0.6
Summary: Numerical Information Field Theory
Home-page: http://www.mpa-garching.mpg.de/ift/nifty/
Author: Theo Steininger
Author-email: theos@mpa-garching.mpg.de
License: GPLv3
Description: UNKNOWN
Platform: UNKNOWN
......@@ -214,9 +214,9 @@ class lm_space(point_space):
return self.paradict['mmax']
def shape(self):
mmax = self.paradict('mmax')
lmax = self.paradict('lmax')
return np.array([(mmax+1)*(lmax+1)-(lmax+1)*(mmax//2)], dtype=int)
mmax = self.paradict['mmax']
lmax = self.paradict['lmax']
return np.array([(mmax+1)*(lmax+1)-((lmax+1)*lmax)//2], dtype=int)
def dim(self,split=False):
"""
......@@ -654,6 +654,7 @@ class lm_space(point_space):
else:
return self._dotlm(x,y)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def calc_transform(self,x,codomain=None,**kwargs):
......
......@@ -2813,13 +2813,7 @@ class field(object):
New field values either as a constant or an arbitrary array.
"""
'''
if(new_val is None):
self.val = np.zeros(self.dim(split=True),dtype=self.domain.datatype,order='C')
else:
self.val = self.domain.enforce_values(new_val,extend=True)
'''
self.val = self.domain.cast(new_val)
self.val = new_val
return self.val
def get_val(self):
......@@ -2913,17 +2907,20 @@ class field(object):
elif isinstance(x, field):
## if x lives in the cospace, transform it an make a
## recursive call
if self.domain.fourier != x.domain.fourier:
return self.dot(x = x.transform())
else:
try:
if self.domain.fourier != x.domain.fourier:
return self.dot(x = x.transform())
except(AttributeError):
pass
## whether the domain matches exactly or not:
## extract the data from x and try to dot with this
return self.dot(x = x.get_val())
return self.dot(x = x.get_val())
## Case 3: x is something else
else:
## Cast the input in order to cure datatype and shape differences
casted_x = self.cast(x)
casted_x = self.domain.cast(x)
## Compute the dot respecting the fact of discrete/continous spaces
if self.domain.discrete == True:
return self.domain.calc_dot(self.get_val(), casted_x)
......
......@@ -411,7 +411,7 @@ class distributed_data_object(object):
found_boolean = (key.dtype == np.bool)
else:
found = 'other'
## TODO: transfer this into distributor:
if (found == 'ndarray' or found == 'd2o') and found_boolean == True:
## extract the data of local relevance
local_bool_array = self.distributor.extract_local_data(key)
......@@ -1541,7 +1541,12 @@ class dtype_converter:
self._to_np_dict = dict(to_np_pre_dict)
def dictionize_np(self, x):
return frozenset(x.__dict__.items())
dic = x.__dict__.items()
if x is np.float:
dic[24] = 0
dic[29] = 0
dic[37] = 0
return frozenset(dic)
def dictionize_mpi(self, x):
return x.name
......
......@@ -483,5 +483,30 @@ def conjugate(x):
"""
return _math_helper(x, np.conjugate)
def direct_dot(x, y):
## the input could be fields. Try to extract the data
try:
x = x.get_val()
except(AttributeError):
pass
## try to make a direct vdot
try:
return x.vdot(y)
except(AttributeError):
pass
try:
return y.vdot(x)
except(AttributeError):
pass
## fallback to numpy
return np.vdot(x, y)
##---------------------------------
\ No newline at end of file
......@@ -29,6 +29,7 @@ from nifty.nifty_core import space, \
from nifty_minimization import conjugate_gradient
from nifty_probing import trace_probing, \
diagonal_probing
from nifty_mpi_probing import prober
##=============================================================================
......@@ -165,6 +166,33 @@ class operator(object):
self.para = para
@property
def val(self):
return self.__val
@val.setter
def val(self, x):
self.__val = self.domain.cast(x)
def set_val(self, new_val):
"""
Resets the field values.
Parameters
----------
new_val : {scalar, ndarray}
New field values either as a constant or an arbitrary array.
"""
self.val = new_val
return self.val
def get_val(self):
return self.val
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def nrow(self):
......@@ -301,7 +329,7 @@ class operator(object):
if self.domain == self.target != x.domain:
x_ = x_.transform(target=x.domain)
if x_.domain == x.domain and (x_.target is not x.target):
x_.set_target(newtarget = x.target)
x_.set_target(new_target = x.target)
return x_
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -1190,27 +1218,9 @@ class diagonal_operator(operator):
self.domain = domain
self.target = self.domain
## Set the diag-val
self.val = self.domain.cast(diag)
## Weight if necessary
if self.domain.discrete == False and bare == True:
self.val = self.domain.calc_weight(self.val, power = 1)
## Check complexity attributes
if self.domain.calc_real_Q(self.val) == True:
self.sym = True
else:
self.sym = False
## Check if unitary, i.e. identity
if (self.val == 1).all() == True:
self.uni = True
else:
self.uni = False
self.imp = True
self.set_diag(new_diag = diag)
"""
if(domain is None)and(isinstance(diag,field)):
......@@ -1243,13 +1253,13 @@ class diagonal_operator(operator):
"""
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def set_diag(self,newdiag,bare=False):
def set_diag(self, new_diag, bare=False):
"""
Sets the diagonal of the diagonal operator
Parameters
----------
newdiag : {scalar, ndarray, field}
new_diag : {scalar, ndarray, field}
The new diagonal entries of the operator. For a scalar, a
constant diagonal is defined having the value provided. If
no domain is given, diag must be a field.
......@@ -1263,7 +1273,28 @@ class diagonal_operator(operator):
-------
None
"""
newdiag = self.domain.enforce_values(newdiag,extend=True)
## Set the diag-val
self.val = self.domain.cast(new_diag)
## Weight if necessary
if self.domain.discrete == False and bare == True:
self.val = self.domain.calc_weight(self.val, power = 1)
## Check complexity attributes
if self.domain.calc_real_Q(self.val) == True:
self.sym = True
else:
self.sym = False
## Check if unitary, i.e. identity
if (self.val == 1).all() == True:
self.uni = True
else:
self.uni = False
"""
newdiag = self.domain.enforce_values(newdiag, extend=True)
## weight if ...
if(not self.domain.discrete)and(bare):
newdiag = self.domain.calc_weight(newdiag,power=1)
......@@ -1281,21 +1312,50 @@ class diagonal_operator(operator):
self.uni = True
else:
self.uni = False
"""
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def _multiply(self,x,**kwargs): ## > applies the operator to a given field
def _multiply(self, x, **kwargs):
## > applies the operator to a given field
y = x.copy(domain = self.target, target = self.target.get_codomain())
y *= self.get_val()
return y
"""
x_ = field(self.target,val=None,target=x.target)
x_.val = x.val*self.val ## bypasses self.domain.enforce_values
return x_
"""
def _adjoint_multiply(self,x,**kwargs): ## > applies the adjoint operator to a given field
def _adjoint_multiply(self, x, **kwargs):
## > applies the adjoint operator to a given field
y = x.copy(domain = self.domain, target = self.domain.get_codomain())
y *= self.get_val().conjugate()
return y
"""
x_ = field(self.domain,val=None,target=x.target)
x_.val = x.val*np.conjugate(self.val) ## bypasses self.domain.enforce_values
return x_
"""
def _inverse_multiply(self,x,pseudo=False,**kwargs): ## > applies the inverse operator to a given field
if(np.any(self.val==0)):
def _inverse_multiply(self, x, pseudo=False, **kwargs):
## > applies the inverse operator to a given field
y = x.copy(domain = self.domain, target = self.domain.get_codomain())
if (self.get_val() == 0).any():
if pseudo == False:
raise AttributeError(about._errors.cstring(
"ERROR: singular operator."))
else:
y /= self.get_val()
y[y == np.nan] = 0
y[y == np.inf] = 0
else:
y /= self.get_val()
return y
"""
if (np.any(self.val==0)):
if(pseudo):
x_ = field(self.domain,val=None,target=x.target)
x_.val = np.ma.filled(x.val/np.ma.masked_where(self.val==0,self.val,copy=False),fill_value=0) ## bypasses self.domain.enforce_values
......@@ -1306,8 +1366,25 @@ class diagonal_operator(operator):
x_ = field(self.domain,val=None,target=x.target)
x_.val = x.val/self.val ## bypasses self.domain.enforce_values
return x_
def _adjoint_inverse_multiply(self,x,pseudo=False,**kwargs): ## > applies the inverse adjoint operator to a given field
"""
def _adjoint_inverse_multiply(self, x, pseudo=False, **kwargs):
## > applies the inverse adjoint operator to a given field
y = x.copy(domain = self.target, target = self.target.get_codomain())
if (self.get_val() == 0).any():
if pseudo == False:
raise AttributeError(about._errors.cstring(
"ERROR: singular operator."))
else:
y /= self.get_val().conjugate()
y[y == np.nan] = 0
y[y == np.inf] = 0
else:
y /= self.get_val().conjugate()
return y
"""
if(np.any(self.val==0)):
if(pseudo):
x_ = field(self.domain,val=None,target=x.target)
......@@ -1319,8 +1396,11 @@ class diagonal_operator(operator):
x_ = field(self.target,val=None,target=x.target)
x_.val = x.val/np.conjugate(self.val) ## bypasses self.domain.enforce_values
return x_
def _inverse_adjoint_multiply(self,x,pseudo=False,**kwargs): ## > applies the adjoint inverse operator to a given field
"""
def _inverse_adjoint_multiply(self, x, pseudo=False, **kwargs):
## > applies the adjoint inverse operator to a given field
return self._adjoint_inverse_multiply(x, pseudo = pseudo, **kwargs)
"""
if(np.any(self.val==0)):
if(pseudo):
x_ = field(self.domain,val=None,target=x.target)
......@@ -1333,7 +1413,7 @@ class diagonal_operator(operator):
x_ = field(self.target,val=None,target=x.target)
x_.val = x.val*np.conjugate(1/self.val) ## bypasses self.domain.enforce_values
return x_
"""
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def tr(self,domain=None,**kwargs):
......@@ -2400,13 +2480,13 @@ class projection_operator(operator):
if(x_.domain.nest[-1]!=x.domain):
x_ = x_.transform(target=nested_space([point_space(len(self.ind),datatype=x.domain.datatype),x.domain]),overwrite=False) ## ... domain
if(x_.target.nest[-1]!=x.target):
x_.set_target(newtarget=nested_space([point_space(len(self.ind),datatype=x.target.datatype),x.target])) ## ... codomain
x_.set_target(new_target=nested_space([point_space(len(self.ind),datatype=x.target.datatype),x.target])) ## ... codomain
else:
## repair ...
if(x_.domain!=x.domain):
x_ = x_.transform(target=x.domain,overwrite=False) ## ... domain
if(x_.target!=x.target):
x_.set_target(newtarget=x.target) ## ... codomain
x_.set_target(new_target=x.target) ## ... codomain
return x_
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -3071,7 +3151,7 @@ class response_operator(operator):
if(self.domain==self.target!=x.domain):
x_ = x_.transform(target=x.domain,overwrite=False) ## ... domain
if(x_.domain==x.domain)and(x_.target!=x.target):
x_.set_target(newtarget=x.target) ## ... codomain
x_.set_target(new_target=x.target) ## ... codomain
return x_
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -3495,7 +3575,7 @@ class propagator_operator(operator):
if(in_codomain)and(x.domain!=self.codomain):
x_ = x_.transform(target=x.domain,overwrite=False) ## ... domain
if(x_.target!=x.target):
x_.set_target(newtarget=x.target) ## ... codomain
x_.set_target(new_target=x.target) ## ... codomain
return x_
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment