Commit 988981a0 authored by Ultima's avatar Ultima
Browse files

Some bugfixes.

Added direct_dot to nifty_simple_math.py.
Begun with direct probing implementation.
parent 148e6fbc
Metadata-Version: 1.0
Name: ift_nifty
Version: 1.0.6
Summary: Numerical Information Field Theory
Home-page: http://www.mpa-garching.mpg.de/ift/nifty/
Author: Theo Steininger
Author-email: theos@mpa-garching.mpg.de
License: GPLv3
Description: UNKNOWN
Platform: UNKNOWN
...@@ -214,9 +214,9 @@ class lm_space(point_space): ...@@ -214,9 +214,9 @@ class lm_space(point_space):
return self.paradict['mmax'] return self.paradict['mmax']
def shape(self): def shape(self):
mmax = self.paradict('mmax') mmax = self.paradict['mmax']
lmax = self.paradict('lmax') lmax = self.paradict['lmax']
return np.array([(mmax+1)*(lmax+1)-(lmax+1)*(mmax//2)], dtype=int) return np.array([(mmax+1)*(lmax+1)-((lmax+1)*lmax)//2], dtype=int)
def dim(self,split=False): def dim(self,split=False):
""" """
...@@ -654,6 +654,7 @@ class lm_space(point_space): ...@@ -654,6 +654,7 @@ class lm_space(point_space):
else: else:
return self._dotlm(x,y) return self._dotlm(x,y)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def calc_transform(self,x,codomain=None,**kwargs): def calc_transform(self,x,codomain=None,**kwargs):
......
...@@ -2813,13 +2813,7 @@ class field(object): ...@@ -2813,13 +2813,7 @@ class field(object):
New field values either as a constant or an arbitrary array. New field values either as a constant or an arbitrary array.
""" """
''' self.val = new_val
if(new_val is None):
self.val = np.zeros(self.dim(split=True),dtype=self.domain.datatype,order='C')
else:
self.val = self.domain.enforce_values(new_val,extend=True)
'''
self.val = self.domain.cast(new_val)
return self.val return self.val
def get_val(self): def get_val(self):
...@@ -2913,17 +2907,20 @@ class field(object): ...@@ -2913,17 +2907,20 @@ class field(object):
elif isinstance(x, field): elif isinstance(x, field):
## if x lives in the cospace, transform it an make a ## if x lives in the cospace, transform it an make a
## recursive call ## recursive call
if self.domain.fourier != x.domain.fourier: try:
return self.dot(x = x.transform()) if self.domain.fourier != x.domain.fourier:
else: return self.dot(x = x.transform())
except(AttributeError):
pass
## whether the domain matches exactly or not: ## whether the domain matches exactly or not:
## extract the data from x and try to dot with this ## extract the data from x and try to dot with this
return self.dot(x = x.get_val()) return self.dot(x = x.get_val())
## Case 3: x is something else ## Case 3: x is something else
else: else:
## Cast the input in order to cure datatype and shape differences ## Cast the input in order to cure datatype and shape differences
casted_x = self.cast(x) casted_x = self.domain.cast(x)
## Compute the dot respecting the fact of discrete/continous spaces ## Compute the dot respecting the fact of discrete/continous spaces
if self.domain.discrete == True: if self.domain.discrete == True:
return self.domain.calc_dot(self.get_val(), casted_x) return self.domain.calc_dot(self.get_val(), casted_x)
......
...@@ -411,7 +411,7 @@ class distributed_data_object(object): ...@@ -411,7 +411,7 @@ class distributed_data_object(object):
found_boolean = (key.dtype == np.bool) found_boolean = (key.dtype == np.bool)
else: else:
found = 'other' found = 'other'
## TODO: transfer this into distributor:
if (found == 'ndarray' or found == 'd2o') and found_boolean == True: if (found == 'ndarray' or found == 'd2o') and found_boolean == True:
## extract the data of local relevance ## extract the data of local relevance
local_bool_array = self.distributor.extract_local_data(key) local_bool_array = self.distributor.extract_local_data(key)
...@@ -1541,7 +1541,12 @@ class dtype_converter: ...@@ -1541,7 +1541,12 @@ class dtype_converter:
self._to_np_dict = dict(to_np_pre_dict) self._to_np_dict = dict(to_np_pre_dict)
def dictionize_np(self, x): def dictionize_np(self, x):
return frozenset(x.__dict__.items()) dic = x.__dict__.items()
if x is np.float:
dic[24] = 0
dic[29] = 0
dic[37] = 0
return frozenset(dic)
def dictionize_mpi(self, x): def dictionize_mpi(self, x):
return x.name return x.name
......
...@@ -483,5 +483,30 @@ def conjugate(x): ...@@ -483,5 +483,30 @@ def conjugate(x):
""" """
return _math_helper(x, np.conjugate) return _math_helper(x, np.conjugate)
def direct_dot(x, y):
## the input could be fields. Try to extract the data
try:
x = x.get_val()
except(AttributeError):
pass
## try to make a direct vdot
try:
return x.vdot(y)
except(AttributeError):
pass
try:
return y.vdot(x)
except(AttributeError):
pass
## fallback to numpy
return np.vdot(x, y)
##--------------------------------- ##---------------------------------
\ No newline at end of file
...@@ -29,6 +29,7 @@ from nifty.nifty_core import space, \ ...@@ -29,6 +29,7 @@ from nifty.nifty_core import space, \
from nifty_minimization import conjugate_gradient from nifty_minimization import conjugate_gradient
from nifty_probing import trace_probing, \ from nifty_probing import trace_probing, \
diagonal_probing diagonal_probing
from nifty_mpi_probing import prober
##============================================================================= ##=============================================================================
...@@ -165,6 +166,33 @@ class operator(object): ...@@ -165,6 +166,33 @@ class operator(object):
self.para = para self.para = para
@property
def val(self):
return self.__val
@val.setter
def val(self, x):
self.__val = self.domain.cast(x)
def set_val(self, new_val):
"""
Resets the field values.
Parameters
----------
new_val : {scalar, ndarray}
New field values either as a constant or an arbitrary array.
"""
self.val = new_val
return self.val
def get_val(self):
return self.val
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def nrow(self): def nrow(self):
...@@ -301,7 +329,7 @@ class operator(object): ...@@ -301,7 +329,7 @@ class operator(object):
if self.domain == self.target != x.domain: if self.domain == self.target != x.domain:
x_ = x_.transform(target=x.domain) x_ = x_.transform(target=x.domain)
if x_.domain == x.domain and (x_.target is not x.target): if x_.domain == x.domain and (x_.target is not x.target):
x_.set_target(newtarget = x.target) x_.set_target(new_target = x.target)
return x_ return x_
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
...@@ -1190,27 +1218,9 @@ class diagonal_operator(operator): ...@@ -1190,27 +1218,9 @@ class diagonal_operator(operator):
self.domain = domain self.domain = domain
self.target = self.domain self.target = self.domain
## Set the diag-val
self.val = self.domain.cast(diag)
## Weight if necessary
if self.domain.discrete == False and bare == True:
self.val = self.domain.calc_weight(self.val, power = 1)
## Check complexity attributes
if self.domain.calc_real_Q(self.val) == True:
self.sym = True
else:
self.sym = False
## Check if unitary, i.e. identity
if (self.val == 1).all() == True:
self.uni = True
else:
self.uni = False
self.imp = True self.imp = True
self.set_diag(new_diag = diag)
""" """
if(domain is None)and(isinstance(diag,field)): if(domain is None)and(isinstance(diag,field)):
...@@ -1243,13 +1253,13 @@ class diagonal_operator(operator): ...@@ -1243,13 +1253,13 @@ class diagonal_operator(operator):
""" """
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def set_diag(self,newdiag,bare=False): def set_diag(self, new_diag, bare=False):
""" """
Sets the diagonal of the diagonal operator Sets the diagonal of the diagonal operator
Parameters Parameters
---------- ----------
newdiag : {scalar, ndarray, field} new_diag : {scalar, ndarray, field}
The new diagonal entries of the operator. For a scalar, a The new diagonal entries of the operator. For a scalar, a
constant diagonal is defined having the value provided. If constant diagonal is defined having the value provided. If
no domain is given, diag must be a field. no domain is given, diag must be a field.
...@@ -1263,7 +1273,28 @@ class diagonal_operator(operator): ...@@ -1263,7 +1273,28 @@ class diagonal_operator(operator):
------- -------
None None
""" """
newdiag = self.domain.enforce_values(newdiag,extend=True)
## Set the diag-val
self.val = self.domain.cast(new_diag)
## Weight if necessary
if self.domain.discrete == False and bare == True:
self.val = self.domain.calc_weight(self.val, power = 1)
## Check complexity attributes
if self.domain.calc_real_Q(self.val) == True:
self.sym = True
else:
self.sym = False
## Check if unitary, i.e. identity
if (self.val == 1).all() == True:
self.uni = True
else:
self.uni = False
"""
newdiag = self.domain.enforce_values(newdiag, extend=True)
## weight if ... ## weight if ...
if(not self.domain.discrete)and(bare): if(not self.domain.discrete)and(bare):
newdiag = self.domain.calc_weight(newdiag,power=1) newdiag = self.domain.calc_weight(newdiag,power=1)
...@@ -1281,21 +1312,50 @@ class diagonal_operator(operator): ...@@ -1281,21 +1312,50 @@ class diagonal_operator(operator):
self.uni = True self.uni = True
else: else:
self.uni = False self.uni = False
"""
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def _multiply(self,x,**kwargs): ## > applies the operator to a given field def _multiply(self, x, **kwargs):
## > applies the operator to a given field
y = x.copy(domain = self.target, target = self.target.get_codomain())
y *= self.get_val()
return y
"""
x_ = field(self.target,val=None,target=x.target) x_ = field(self.target,val=None,target=x.target)
x_.val = x.val*self.val ## bypasses self.domain.enforce_values x_.val = x.val*self.val ## bypasses self.domain.enforce_values
return x_ return x_
"""
def _adjoint_multiply(self,x,**kwargs): ## > applies the adjoint operator to a given field def _adjoint_multiply(self, x, **kwargs):
## > applies the adjoint operator to a given field
y = x.copy(domain = self.domain, target = self.domain.get_codomain())
y *= self.get_val().conjugate()
return y
"""
x_ = field(self.domain,val=None,target=x.target) x_ = field(self.domain,val=None,target=x.target)
x_.val = x.val*np.conjugate(self.val) ## bypasses self.domain.enforce_values x_.val = x.val*np.conjugate(self.val) ## bypasses self.domain.enforce_values
return x_ return x_
"""
def _inverse_multiply(self,x,pseudo=False,**kwargs): ## > applies the inverse operator to a given field def _inverse_multiply(self, x, pseudo=False, **kwargs):
if(np.any(self.val==0)): ## > applies the inverse operator to a given field
y = x.copy(domain = self.domain, target = self.domain.get_codomain())
if (self.get_val() == 0).any():
if pseudo == False:
raise AttributeError(about._errors.cstring(
"ERROR: singular operator."))
else:
y /= self.get_val()
y[y == np.nan] = 0
y[y == np.inf] = 0
else:
y /= self.get_val()
return y
"""
if (np.any(self.val==0)):
if(pseudo): if(pseudo):
x_ = field(self.domain,val=None,target=x.target) x_ = field(self.domain,val=None,target=x.target)
x_.val = np.ma.filled(x.val/np.ma.masked_where(self.val==0,self.val,copy=False),fill_value=0) ## bypasses self.domain.enforce_values x_.val = np.ma.filled(x.val/np.ma.masked_where(self.val==0,self.val,copy=False),fill_value=0) ## bypasses self.domain.enforce_values
...@@ -1306,8 +1366,25 @@ class diagonal_operator(operator): ...@@ -1306,8 +1366,25 @@ class diagonal_operator(operator):
x_ = field(self.domain,val=None,target=x.target) x_ = field(self.domain,val=None,target=x.target)
x_.val = x.val/self.val ## bypasses self.domain.enforce_values x_.val = x.val/self.val ## bypasses self.domain.enforce_values
return x_ return x_
"""
def _adjoint_inverse_multiply(self,x,pseudo=False,**kwargs): ## > applies the inverse adjoint operator to a given field
def _adjoint_inverse_multiply(self, x, pseudo=False, **kwargs):
## > applies the inverse adjoint operator to a given field
y = x.copy(domain = self.target, target = self.target.get_codomain())
if (self.get_val() == 0).any():
if pseudo == False:
raise AttributeError(about._errors.cstring(
"ERROR: singular operator."))
else:
y /= self.get_val().conjugate()
y[y == np.nan] = 0
y[y == np.inf] = 0
else:
y /= self.get_val().conjugate()
return y
"""
if(np.any(self.val==0)): if(np.any(self.val==0)):
if(pseudo): if(pseudo):
x_ = field(self.domain,val=None,target=x.target) x_ = field(self.domain,val=None,target=x.target)
...@@ -1319,8 +1396,11 @@ class diagonal_operator(operator): ...@@ -1319,8 +1396,11 @@ class diagonal_operator(operator):
x_ = field(self.target,val=None,target=x.target) x_ = field(self.target,val=None,target=x.target)
x_.val = x.val/np.conjugate(self.val) ## bypasses self.domain.enforce_values x_.val = x.val/np.conjugate(self.val) ## bypasses self.domain.enforce_values
return x_ return x_
"""
def _inverse_adjoint_multiply(self,x,pseudo=False,**kwargs): ## > applies the adjoint inverse operator to a given field def _inverse_adjoint_multiply(self, x, pseudo=False, **kwargs):
## > applies the adjoint inverse operator to a given field
return self._adjoint_inverse_multiply(x, pseudo = pseudo, **kwargs)
"""
if(np.any(self.val==0)): if(np.any(self.val==0)):
if(pseudo): if(pseudo):
x_ = field(self.domain,val=None,target=x.target) x_ = field(self.domain,val=None,target=x.target)
...@@ -1333,7 +1413,7 @@ class diagonal_operator(operator): ...@@ -1333,7 +1413,7 @@ class diagonal_operator(operator):
x_ = field(self.target,val=None,target=x.target) x_ = field(self.target,val=None,target=x.target)
x_.val = x.val*np.conjugate(1/self.val) ## bypasses self.domain.enforce_values x_.val = x.val*np.conjugate(1/self.val) ## bypasses self.domain.enforce_values
return x_ return x_
"""
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def tr(self,domain=None,**kwargs): def tr(self,domain=None,**kwargs):
...@@ -2400,13 +2480,13 @@ class projection_operator(operator): ...@@ -2400,13 +2480,13 @@ class projection_operator(operator):
if(x_.domain.nest[-1]!=x.domain): if(x_.domain.nest[-1]!=x.domain):
x_ = x_.transform(target=nested_space([point_space(len(self.ind),datatype=x.domain.datatype),x.domain]),overwrite=False) ## ... domain x_ = x_.transform(target=nested_space([point_space(len(self.ind),datatype=x.domain.datatype),x.domain]),overwrite=False) ## ... domain
if(x_.target.nest[-1]!=x.target): if(x_.target.nest[-1]!=x.target):
x_.set_target(newtarget=nested_space([point_space(len(self.ind),datatype=x.target.datatype),x.target])) ## ... codomain x_.set_target(new_target=nested_space([point_space(len(self.ind),datatype=x.target.datatype),x.target])) ## ... codomain
else: else:
## repair ... ## repair ...
if(x_.domain!=x.domain): if(x_.domain!=x.domain):
x_ = x_.transform(target=x.domain,overwrite=False) ## ... domain x_ = x_.transform(target=x.domain,overwrite=False) ## ... domain
if(x_.target!=x.target): if(x_.target!=x.target):
x_.set_target(newtarget=x.target) ## ... codomain x_.set_target(new_target=x.target) ## ... codomain
return x_ return x_
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
...@@ -3071,7 +3151,7 @@ class response_operator(operator): ...@@ -3071,7 +3151,7 @@ class response_operator(operator):
if(self.domain==self.target!=x.domain): if(self.domain==self.target!=x.domain):
x_ = x_.transform(target=x.domain,overwrite=False) ## ... domain x_ = x_.transform(target=x.domain,overwrite=False) ## ... domain
if(x_.domain==x.domain)and(x_.target!=x.target): if(x_.domain==x.domain)and(x_.target!=x.target):
x_.set_target(newtarget=x.target) ## ... codomain x_.set_target(new_target=x.target) ## ... codomain
return x_ return x_
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
...@@ -3495,7 +3575,7 @@ class propagator_operator(operator): ...@@ -3495,7 +3575,7 @@ class propagator_operator(operator):
if(in_codomain)and(x.domain!=self.codomain): if(in_codomain)and(x.domain!=self.codomain):
x_ = x_.transform(target=x.domain,overwrite=False) ## ... domain x_ = x_.transform(target=x.domain,overwrite=False) ## ... domain
if(x_.target!=x.target): if(x_.target!=x.target):
x_.set_target(newtarget=x.target) ## ... codomain x_.set_target(new_target=x.target) ## ... codomain
return x_ return x_
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment