Commit e8fe6581 authored by Ultimanet's avatar Ultimanet
Browse files

Updated nifty_core.field to really delegate everything to its underlying...

Updated nifty_core.field to really delegate everything to its underlying space. The methods pseudo_dot and tensor_dot are not fixed yet.
parent 3e0aa18d
......@@ -894,10 +894,10 @@ class space(object):
def __len__(self):
return int(self.dim(split=False))
## __identiftier__ returns an object which contains all information needed
## _identiftier returns an object which contains all information needed
## to uniquely idetnify a space. It returns a (immutable) tuple which therefore
## can be compored.
def __identifier__(self):
def _identifier(self):
return tuple(sorted(vars(self).items()))
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -910,10 +910,10 @@ class space(object):
return mars
def __eq__(self,x): ## __eq__ : self == x
if(isinstance(x,space)):
if(isinstance(x,type(self)))and(np.all(self.para==x.para))and(self.discrete==x.discrete)and(np.all(self.vol==x.vol))and(np.all(self._meta_vars()==x._meta_vars())): ## data types are ignored
return True
return False
if isinstance(x, type(self)):
return self._identifier() == x._identifier()
else:
return False
def __ne__(self,x): ## __ne__ : self <> x
if(isinstance(x,space)):
......@@ -1695,10 +1695,10 @@ class point_space(space):
return x ## T == id
## check codomain
self.check_codomain(codomain) ## a bit pointless
assert(self.check_codomain(codomain))
if(self==codomain):
return x ## T == id
if self == codomain:
return x
else:
raise ValueError(about._errors.cstring("ERROR: unsupported transformation."))
......@@ -2426,7 +2426,7 @@ class nested_space(space):
return x ## T == id
## check codomain
self.check_codomain(codomain) ## a bit pointless
assert(self.check_codomain(codomain))
if(self==codomain)and(coorder is None):
return x ## T == id
......@@ -2657,37 +2657,25 @@ class field(object):
"""
## check domain
if(not isinstance(domain,space)):
if not isinstance(domain,space):
raise TypeError(about._errors.cstring("ERROR: invalid input."))
self.domain = domain
## check codomain
if(target is None):
if target is None:
target = domain.get_codomain()
else:
self.domain.check_codomain(target)
assert(self.domain.check_codomain(target))
self.target = target
if val == None:
if kwargs == {}:
self.val = self.domain.datatype(0)
self.val = self.domain.cast(0)
else:
self.val = self.domain.get_random_values(codomain=self.target,**kwargs)
self.val = self.domain.get_random_values(codomain=self.target,
**kwargs)
else:
self.val = val
"""
self.distributed_val = distributed_data_object(global_shape=domain.dim(split=True), dtype=domain.datatype)
## check values
if(val is None):
self.val = self.domain.get_random_values(codomain=self.target,**kwargs)
else:
self.val = self.domain.enforce_values(val,extend=True)
"""
@property
def val(self):
......@@ -2714,7 +2702,7 @@ class field(object):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def dim(self,split=False):
def dim(self, split=False):
"""
Computes the (array) dimension of the underlying space.
......@@ -2735,16 +2723,16 @@ class field(object):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def cast_domain(self,newdomain,newtarget=None,force=True):
def cast_domain(self, newdomain, new_target=None, force=True):
"""
Casts the domain of the field.
Casts the domain of the field.
Parameters
----------
newdomain : space
New space wherein the field should live.
newtarget : space, *optional*
new_target : space, *optional*
Space wherein the transform of the field should live.
When not given, target will automatically be the codomain
of the newly casted domain (default=None).
......@@ -2758,78 +2746,102 @@ class field(object):
Nothing
"""
if(not isinstance(newdomain,space)):
## Check if the newdomain is a space
if not isinstance(newdomain,space):
raise TypeError(about._errors.cstring("ERROR: invalid input."))
elif(newdomain.datatype is not self.domain.datatype):
raise TypeError(about._errors.cstring("ERROR: inequal data types '"+str(np.result_type(newdomain.datatype))+"' and '"+str(np.result_type(self.domain.datatype))+"'."))
elif(newdomain.dim(split=False)!=self.domain.dim(split=False)):
raise ValueError(about._errors.cstring("ERROR: dimension mismatch ( "+str(newdomain.dim(split=False))+" <> "+str(self.domain.dim(split=False))+" )."))
if(force):
newshape = newdomain.dim(split=True)
if(not np.all(newshape==self.domain.dim(split=True))):
about.infos.cprint("INFO: reshaping forced.")
self.val.shape = newshape
## Check if the datatypes match
elif newdomain.datatype != self.domain.datatype:
raise TypeError(about._errors.cstring(
"ERROR: inequal data types '" +
str(np.result_type(newdomain.datatype)) +
"' and '" + str(np.result_type(self.domain.datatype)) +
"'."))
## Check if the total dimensions match
elif newdomain.dim() != self.domain.dim():
raise ValueError(about._errors.cstring(
"ERROR: dimension mismatch ( " + str(newdomain.dim()) +
" <> " + str(self.domain.dim()) + " )."))
if force == True:
self.set_domain(new_domain = newdomain, force = True)
else:
if(not np.all(newdomain.dim(split=True)==self.domain.dim(split=True))):
raise ValueError(about._errors.cstring("ERROR: shape mismatch ( "+str(newdomain.dim(split=True))+" <> "+str(self.domain.dim(split=True))+" )."))
self.domain = newdomain
if not np.all(newdomain.dim(split=True) == \
self.domain.dim(split=True)):
raise ValueError(about._errors.cstring(
"ERROR: shape mismatch ( " + str(newdomain.dim(split=True)) +
" <> " + str(self.domain.dim(split=True)) + " )."))
else:
self.domain = newdomain
## Use the casting of the new domain in order to make the old data fit.
self.set_val(new_val = self.val)
## check target
if(newtarget is None):
if(not self.domain.check_codomain(self.target)):
## set the target
if new_target == None:
if not self.domain.check_codomain(self.target):
if(force):
about.infos.cprint("INFO: codomain set to default.")
else:
about.warnings.cprint("WARNING: codomain set to default.")
self.set_target(newtarget=self.domain.get_codomain())
self.set_target(new_target = self.domain.get_codomain())
else:
self.set_target(newtarget=newtarget)
self.set_target(new_target = new_target, force = force)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def set_val(self, newval):
def set_val(self, new_val):
"""
Resets the field values.
Parameters
----------
newval : {scalar, ndarray}
new_val : {scalar, ndarray}
New field values either as a constant or an arbitrary array.
"""
'''
if(newval is None):
if(new_val is None):
self.val = np.zeros(self.dim(split=True),dtype=self.domain.datatype,order='C')
else:
self.val = self.domain.enforce_values(newval,extend=True)
self.val = self.domain.enforce_values(new_val,extend=True)
'''
self.val = self.domain.cast(newval)
self.val = self.domain.cast(new_val)
return self.val
def get_val(self):
return self.val
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def set_domain(self, new_domain=None, force=False):
if new_domain is None:
new_domain = self.target.get_codomain()
elif force == False:
assert(self.target.check_codomain(new_domain))
self.domain = new_domain
return self.domain
def set_target(self,newtarget=None):
def set_target(self, new_target=None, force=False):
"""
Resets the codomain of the field.
Parameters
----------
newtarget : space
new_target : space
The new space wherein the transform of the field should live.
(default=None).
"""
## check codomain
if(newtarget is None):
newtarget = self.domain.get_codomain()
else:
self.domain.check_codomain(newtarget)
self.target = newtarget
if new_target is None:
new_target = self.domain.get_codomain()
elif force == False:
assert(self.domain.check_codomain(new_target))
self.target = new_target
return self.target
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def weight(self,power=1,overwrite=False):
def weight(self, power=1, overwrite=False):
"""
Returns the field values, weighted with the volume factors to a
given power. The field values will optionally be overwritten.
......@@ -2850,14 +2862,17 @@ class field(object):
Otherwise, nothing is returned.
"""
if(overwrite):
self.val = self.domain.calc_weight(self.val,power=power)
if overwrite == True:
new_field = self
else:
return field(self.domain,val=self.domain.calc_weight(self.val,power=power),target=self.target)
new_field = self.copy_empty()
new_field.set_val(new_val = self.domain.calc_weight(self.get_val(),
power = power))
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def dot(self,x=None):
def dot(self, x=None):
"""
Computes the inner product of the field with a given object
implying the correct volume factor needed to reflect the
......@@ -2875,42 +2890,35 @@ class field(object):
The result of the inner product.
"""
if(x is None):
x = self.val
if(isinstance(x,field)):
if(self.domain!=x.domain):
try: ## to transform field
x = x.transform(target=self.domain,overwrite=False)
except(ValueError):
if(np.size(x.dim(split=True))>np.size(self.dim(split=True))): ## switch
return x.dot(x=self)
else:
try: ## to complete subfield
x = field(self.domain,val=x,target=self.target)
except(TypeError,ValueError):
try: ## to complete transformed subfield
x = field(self.domain,val=x.transform(target=x.target,overwrite=False),target=self.target)
except(TypeError,ValueError):
raise ValueError(about._errors.cstring("ERROR: incompatible domains."))
if(x.domain.datatype>self.domain.datatype):
if(not self.domain.discrete):
return x.domain.calc_dot(self.val.astype(x.domain.datatype),x.weight(power=1,overwrite=False))
else:
return x.domain.calc_dot(self.val.astype(x.domain.datatype),x.val)
## Case 1: x equals None
if x == None:
return None
## Case 2: x is a field
elif isinstance(x, field):
## if x lives in the cospace, transform it an make a
## recursive call
if self.domain.fourier != x.domain.fourier:
return self.dot(x = x.transform())
else:
if(not self.domain.discrete):
return self.domain.calc_dot(self.val,self.domain.calc_weight(x.val,power=1))
#return self.domain.calc_dot(self.val,self.domain.calc_weight(x.val.astype(self.domain.datatype),power=1))
else:
return self.domain.calc_dot(self.val,x.val)
#return self.domain.calc_dot(self.val,x.val.astype(self.domain.datatype))
## whether the domain matches exactly or not:
## extract the data from x and try to dot with this
return self.dot(x = x.get_val())
## Case 3: x is something else
else:
x = self.domain.enforce_values(x,extend=True)
if(not self.domain.discrete):
x = self.domain.calc_weight(x,power=1)
return self.domain.calc_dot(self.val,x)
## Cast the input in order to cure datatype and shape differences
casted_x = self.cast(x)
## Compute the dot respecting the fact of discrete/continous spaces
if self.domain.discrete == True:
return self.domain.calc_dot(self.get_val(), casted_x)
else:
return self.domain.calc_dot(self.get_val(),
self.domain.calc_weight(
casted_x,
power=1))
def norm(self,q=None):
def norm(self, q=0.5):
"""
Computes the Lq-norm of the field values.
......@@ -2925,12 +2933,14 @@ class field(object):
The Lq-norm of the field values.
"""
if(q is None):
return np.sqrt(self.dot(x=self.val))
if q == 0.5:
return (self.dot(x = self))**(1/2)
else:
return self.dot(x=self.val**(q-1))**(1/q)
return self.dot(x = self**(q-1))**(1/q)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## TODO: rework the nested space semantics in order to become compatible
## with the usual space interface
def pseudo_dot(self,x=1,**kwargs):
"""
......@@ -3007,6 +3017,8 @@ class field(object):
return self.domain.calc_pseudo_dot(self.val,x,**kwargs)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## TODO: rework the nested space semantics in order to become compatible
## with the usual space interface
def tensor_dot(self,x=None,**kwargs):
"""
......@@ -3041,7 +3053,7 @@ class field(object):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def conjugate(self):
def conjugate(self, inplace=False):
"""
Computes the complex conjugate of the field.
......@@ -3051,11 +3063,16 @@ class field(object):
The complex conjugated field.
"""
return field(self.domain,val=np.conjugate(self.val),target=self.target)
if inplace == True:
work_field = self
else:
work_field = self.copy_empty()
work_field.set_val(new_val = self.val.conjugate())
return work_field
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def transform(self, target=None, overwrite=False, **kwargs):
def transform(self, target=None, overwrite=False, **kwargs):
"""
Computes the transform of the field using the appropriate conjugate
transformation.
......@@ -3083,21 +3100,25 @@ class field(object):
if(target is None):
target = self.target
else:
self.domain.check_codomain(target) ## a bit pointless
assert(self.domain.check_codomain(target))
new_val = self.domain.calc_transform(self.val,
codomain=target,
**kwargs)
if(overwrite):
self.val = new_val
self.target = self.domain
self.domain = target
if overwrite == True:
return_field = self
return_field.set_target(new_target = self.domain, force = True)
return_field.set_domain(new_domain = target, force = True)
else:
return field(target, val=new_val, target=self.domain)
return_field = self.copy_empty(domain = self.target,
target = self.domain)
return_field.set_val(new_val = new_val)
return return_field
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def smooth(self,sigma=0,overwrite=False,**kwargs):
def smooth(self, sigma=0, overwrite=False, **kwargs):
"""
Smoothes the field by convolution with a Gaussian kernel.
......@@ -3122,14 +3143,19 @@ class field(object):
Otherwise, nothing is returned.
"""
if(overwrite):
self.val = self.domain.calc_smooth(self.val,sigma=sigma,**kwargs)
if overwrite == True:
new_field = self
else:
return field(self.domain,val=self.domain.calc_smooth(self.val,sigma=sigma,**kwargs),target=self.target)
new_field = self.copy_empty()
new_field.set_val(new_val = self.domain.calc_smooth(self.get_val(),
sigma = sigma,
**kwargs))
return new_field
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def power(self,**kwargs):
def power(self, **kwargs):
"""
Computes the power spectrum of the field values.
......@@ -3170,7 +3196,11 @@ class field(object):
"""
if("codomain" in kwargs):
kwargs.__delitem__("codomain")
return self.domain.calc_power(self.val,codomain=self.target,**kwargs)
about.warnings.cprint("WARNING: codomain was removed from kwargs.")
return self.domain.calc_power(self.get_val(),
codomain = self.target,
**kwargs)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -3185,7 +3215,9 @@ class field(object):
"""
from nifty.operators.nifty_operators import diagonal_operator
return diagonal_operator(domain=self.domain,diag=self.val,bare=False)
return diagonal_operator(domain=self.domain,
diag=self.get_val(),
bare=False)
def inverse_hat(self):
"""
......@@ -3198,10 +3230,13 @@ class field(object):
"""
if(np.any(self.val==0)):
raise AttributeError(about._errors.cstring("ERROR: singular operator."))
raise AttributeError(
about._errors.cstring("ERROR: singular operator."))
else:
from nifty.operators.nifty_operators import diagonal_operator
return diagonal_operator(domain=self.domain,diag=1/self.val,bare=False)
return diagonal_operator(domain=self.domain,
diag=(1/self).get_val(),
bare=False)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -3268,10 +3303,7 @@ class field(object):
"""
## if a save path is given, set pylab to not-interactive
remember_interactive = pl.isinteractive()
pl.matplotlib.interactive(not bool(
kwargs.get("save", False)
)
)
pl.matplotlib.interactive(not bool(kwargs.get("save", False)))
if "codomain" in kwargs:
kwargs.__delitem__("codomain")
......@@ -3289,9 +3321,14 @@ class field(object):
return "<nifty_core.field>"
def __str__(self):
minmax = [np.min(self.val,axis=None,out=None),np.max(self.val,axis=None,out=None)]
medmean = [np.median(self.val,axis=None,out=None,overwrite_input=False),np.mean(self.val,axis=None,dtype=self.domain.datatype,out=None)]
return "nifty_core.field instance\n- domain = "+repr(self.domain)+"\n- val = [...]"+"\n - min.,max. = "+str(minmax)+"\n - med.,mean = "+str(medmean)+"\n- target = "+repr(self.target)
minmax = [self.val.amin(), self.val.amax()]
mean = self.val.mean()
return "nifty_core.field instance\n- domain = " + \
repr(self.domain) + \
"\n- val = [...]" + \
"\n - min.,max. = " + str(minmax) + \
"\n - mean = " + str(mean) + \
"\n- target = " + repr(self.target)
def __len__(self):
......@@ -3514,14 +3551,14 @@ class field(object):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def __binary_helper__(self, other, op='None'):
new_val = self.domain.binary_operation(self.val, other, op=op, cast=3)
new_val = self.domain.binary_operation(self.val, other, op=op, cast=0)
new_field = self.copy_empty()
new_field.val = new_val
return new_field
def __inplace_binary_helper__(self, other, op='None'):
self.val = self.domain.binary_operation(self.val, other, op=op,
cast=3)
cast=0)
return self
def __add__(self, other):
......
......@@ -1747,7 +1747,7 @@ class rg_space(point_space):
## therefore can be compared.
## The rg_space version of __identifier__ filters out the vars-information
## which is describing the rg_space's structure
def __identifier__(self):
def _identifier(self):
## Extract the identifying parts from the vars(self) dict.
temp = [(ii[0],
((lambda x: tuple(x) if isinstance(x,np.ndarray) else x)(ii[1])))\
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment