Commit 0f1e4c12 authored by Ultima's avatar Ultima
Browse files

Implemented multidimensional data on fields.

parent 7f240569
......@@ -2938,7 +2938,7 @@ class field(object):
The space wherein the operator output lives (default: domain).
"""
def __init__(self, domain, val=None, codomain=None, **kwargs):
def __init__(self, domain, val=None, codomain=None, idim=0, **kwargs):
"""
Sets the attributes for a field class instance.
......@@ -2965,21 +2965,24 @@ class field(object):
if not isinstance(domain,space):
raise TypeError(about._errors.cstring("ERROR: invalid input."))
self.domain = domain
## check codomain
if codomain is None:
codomain = domain.get_codomain()
else:
assert(self.domain.check_codomain(codomain))
self.codomain = codomain
self.idim = np.uint(idim)
if val == None:
if kwargs == {}:
self.val = self.domain.cast(0.)
val = self._map(lambda: self.domain.cast(0.))
else:
self.val = self.domain.get_random_values(codomain=self.codomain,
**kwargs)
else:
self.val = val
val = self._map(lambda: self.domain.get_random_values(
codomain=self.codomain,
**kwargs))
self.set_val(new_val = val)
@property
......@@ -2988,21 +2991,111 @@ class field(object):
@val.setter
def val(self, x):
self.__val = self.domain.cast(x)
self.__val = self.cast(x)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def _map(self, function, *args):
if self.idim == 0:
return function(*args)
else:
if args == ():
result = []
for i in xrange(self.idim):
result.append(function())
return result
else:
return map(function, *args)
def cast(self, x = None):
if self.idim == 0:
scalarized_x = self._cast_to_scalar_helper(x)
return self.domain.cast(scalarized_x)
else:
vectorized_x = self._cast_to_vector_helper(x)
return self._map(lambda z: self.domain.cast(z),
vectorized_x)
def _cast_to_scalar_helper(self, x):
## If x is a list, take the first entry and cast it with self.domain
if isinstance(x, list):
if len(x) >= 1:
if len(x) > 1:
about.warnings.cprint(
"WARNING: discarding all internal dimensions "+\
"except for the first one.")
return x[0]
## If the given list is empty, cast None.
elif len(x) == 0:
return None
## x is an encapsulated data object of right shape
elif np.shape(x) == (1,) + tuple(self.domain.get_shape()):
return x[0]
## In all other cases, cast x directly
else:
return x
def _cast_to_vector_helper(self, x):
## Check if x is a list of proper length
## containing something which will then checked by the domain-space
if isinstance(x, list):
if len(x) == self.idim:
return x
## Slow track: x could be a pure ndarray
x_shape = np.shape(x)
## Case 1: The overall shape is already the right one
if x_shape == (self.idim,) + tuple(self.domain.get_shape()):
## Iterate over the outermost dimension and cast the inner spaces
result = []
for i in xrange(self.idim):
try:
result.append(x[i].copy())
except(AttributeError):
result.append(np.copy(x[i]))
## Case 2: The overall shape does not match directly.
## Check if the input has shape (1, self.domain.shape)
elif x_shape == (1,) + tuple(self.domain.get_shape()):
## Expand the first entry
result = []
for i in xrange(self.idim):
try:
result.append(x[0].copy())
except(AttributeError):
result.append(np.copy(x[0]))
## Case 3: fallback: try to cast x with self.domain
else:
#Expand as is
result = []
for i in xrange(self.idim):
try:
result.append(x.copy())
except(AttributeError):
result.append(x)
return result
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def copy(self, domain=None, codomain=None):
new_field = self.copy_empty(domain=domain, codomain=codomain)
new_field.val = new_field.domain.cast(self.val.copy())
new_field.val = new_field.cast(self.val.copy())
return new_field
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def copy_empty(self, domain=None, codomain=None, **kwargs):
def copy_empty(self, domain=None, codomain=None, idim=None, **kwargs):
if domain == None:
domain = self.domain
if codomain == None:
codomain = self.codomain
new_field = field(domain=domain, codomain=codomain, **kwargs)
if idim == None:
idim = self.idim
new_field = field(domain=domain, codomain=codomain, idim=idim,
**kwargs)
return new_field
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -3026,6 +3119,9 @@ class field(object):
"""
return self.domain.dim(split=split)
def get_idim(self):
return self.idim
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def cast_domain(self, newdomain, new_codomain=None, force=True):
......@@ -3111,6 +3207,7 @@ class field(object):
return self.val
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def set_domain(self, new_domain=None, force=False):
if new_domain is None:
new_domain = self.codomain.get_codomain()
......@@ -3166,9 +3263,12 @@ class field(object):
new_field = self
else:
new_field = self.copy_empty()
new_field.set_val(new_val = self.domain.calc_weight(self.get_val(),
power = power))
new_val = self._map(
lambda y: self.domain.calc_weight(y, power = power),
self.get_val())
new_field.set_val(new_val = new_val)
return new_field
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -3212,16 +3312,21 @@ class field(object):
## Case 3: x is something else
else:
## Cast the input in order to cure datatype and shape differences
casted_x = self.domain.cast(x)
casted_x = self.cast(x)
## Compute the dot respecting the fact of discrete/continous spaces
if self.domain.discrete == True:
return self.domain.calc_dot(self.get_val(), casted_x)
result = self._map(
lambda z1, z2: self.domain.calc_dot(z1, z2),
self.get_val(),
casted_x)
else:
return self.domain.calc_dot(self.get_val(),
self.domain.calc_weight(
casted_x,
power=1))
result = self._map(
lambda z1, z2: self.domain.calc_dot(
z1, self.domain.calc_weight(z2, power = 1)),
self.get_val(), casted_x)
return np.prod(result)
def norm(self, q=0.5):
"""
Computes the Lq-norm of the field values.
......@@ -3371,7 +3476,9 @@ class field(object):
work_field = self
else:
work_field = self.copy_empty()
work_field.set_val(new_val = self.val.conjugate())
new_val = self._map(lambda z: z.conjugate(), self.get_val())
work_field.set_val(new_val = new_val)
return work_field
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -3401,15 +3508,15 @@ class field(object):
Otherwise, nothing is returned.
"""
if(codomain is None):
if codomain is None:
codomain = self.codomain
else:
assert(self.domain.check_codomain(codomain))
new_val = self.domain.calc_transform(self.val,
codomain=codomain,
**kwargs)
new_val = self._map(
lambda z: self.domain.calc_transform(z, codomain=codomain, **kwargs),
self.get_val())
if overwrite == True:
return_field = self
return_field.set_codomain(new_codomain = self.domain, force = True)
......@@ -3451,10 +3558,12 @@ class field(object):
new_field = self
else:
new_field = self.copy_empty()
new_val = self._map(
lambda z: self.domain.calc_smooth(z, sigma = sigma, **kwargs),
self.get_val())
new_field.set_val(new_val = self.domain.calc_smooth(self.get_val(),
sigma = sigma,
**kwargs))
new_field.set_val(new_val = new_val)
return new_field
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -3502,10 +3611,12 @@ class field(object):
kwargs.__delitem__("codomain")
about.warnings.cprint("WARNING: codomain was removed from kwargs.")
return self.domain.calc_power(self.get_val(),
codomain = self.codomain,
**kwargs)
power_spectrum = self._map(
lambda z: self.domain.calc_power(z, codomain = self.codomain,
**kwargs),
self.get_val())
return power_spectrum
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def hat(self):
......@@ -3519,9 +3630,10 @@ class field(object):
"""
from nifty.operators.nifty_operators import diagonal_operator
return diagonal_operator(domain=self.domain,
diag=self.get_val(),
bare=False)
return diagonal_operator(domain = self.domain,
diag = self.get_val(),
bare = False,
idim = self.idim)
def inverse_hat(self):
"""
......@@ -3533,14 +3645,17 @@ class field(object):
The new diagonal operator instance.
"""
if(np.any(self.val==0)):
any_zero_Q = self._map(lambda z: (z==0).any(), self.get_val())
any_zero_Q = np.any(any_zero_Q)
if any_zero_Q == True:
raise AttributeError(
about._errors.cstring("ERROR: singular operator."))
else:
from nifty.operators.nifty_operators import diagonal_operator
return diagonal_operator(domain=self.domain,
diag=(1/self).get_val(),
bare=False)
return diagonal_operator(domain = self.domain,
diag = (1/self).get_val(),
bare = False,
idim = self.idim)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -3625,43 +3740,71 @@ class field(object):
return "<nifty_core.field>"
def __str__(self):
minmax = [self.val.amin(), self.val.amax()]
mean = self.val.mean()
minmax = [self.min(), self.max()]
mean = self.mean()
return "nifty_core.field instance\n- domain = " + \
repr(self.domain) + \
"\n- val = [...]" + \
repr(self.domain) +\
"\n- val = " + repr(self.get_val()) + \
"\n - min.,max. = " + str(minmax) + \
"\n - mean = " + str(mean) + \
"\n- codomain = " + repr(self.codomain)
"\n- codomain = " + repr(self.codomain) + \
"\n- idim = " + str(self.idim)
def __len__(self):
return int(self.dim(split=True)[0])
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def __getitem__(self,key):
return self.domain.getitem(self.val, key)
def __setitem__(self,key,value):
self.domain.setitem(self.val, value, key)
def __getitem__(self, key):
if np.isscalar(key) == True or isinstance(key, slice):
key = (key, )
if self.idim == 0:
return self.domain.getitem(self.get_val(), key)
else:
gotten = self.get_val()[key[0]]
if len(key) > 1:
gotten = self._map(lambda z: self.domain.getitem(z, key[1:]),
gotten)
return gotten
def __setitem__(self, key, value):
if np.isscalar(key) or isinstance(key, slice):
key = (key, )
if self.idim == 0:
return self.domain.setitem(self.get_val(), value, key)
else:
gotten = self.get_val()[key[0]]
if len(key) > 1:
gotten = self._map(
lambda z1, z2: self.domain.setitem(z1, key[1:], z2),
gotten, value)
return gotten
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def apply_scalar_function(self, function, inplace=False):
if inplace == True:
temp = self
working_field = self
else:
temp = self.copy_empty()
data_object = self.domain.apply_scalar_function(self.val,
function,
inplace)
temp.set_val(data_object)
return temp
working_field = self.copy_empty()
data_object = self._map(
lambda z: self.domain.apply_scalar_function(z, function, inplace),
self.get_val())
working_field.set_val(data_object)
return working_field
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def min(self,ignore=False,**kwargs):
def _unary_helper(self, x, op, **kwargs):
result = self._map(
lambda z: self.domain.unary_operation(z, op=op, **kwargs),
self.get_val())
return result
def min(self, ignore=False,**kwargs):
"""
Returns the minimum of the field values.
......@@ -3681,9 +3824,11 @@ class field(object):
"""
if ignore == True:
return self.domain.unary_operation(self.val, op='nanmin', **kwargs)
return self._unary_helper(self.get_val(), op='nanmin',
**kwargs)
else:
return self.domain.unary_operation(self.val, op='min', **kwargs)
return self._unary_helper(self.get_val(), op='min',
**kwargs)
def max(self,ignore=False,**kwargs):
"""
......@@ -3705,10 +3850,11 @@ class field(object):
"""
if ignore == True:
return self.domain.unary_operation(self.val, op='nanmax', **kwargs)
return self._unary_helper(self.get_val(), op='nanmax',
**kwargs)
else:
return self.domain.unary_operation(self.val, op='max', **kwargs)
return self._unary_helper(self.get_val(), op='max',
**kwargs)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def med(self,**kwargs):
......@@ -3725,8 +3871,8 @@ class field(object):
np.median
"""
return self.domain.unary_operation(self.val, op='median', **kwargs)
return self._unary_helper(self.get_val(), op='median',
**kwargs)
def mean(self,**kwargs):
"""
Returns the mean of the field values.
......@@ -3741,8 +3887,8 @@ class field(object):
np.mean
"""
return self.domain.unary_operation(self.val, op='mean', **kwargs)
return self._unary_helper(self.get_val(), op='mean',
**kwargs)
def std(self,**kwargs):
"""
Returns the standard deviation of the field values.
......@@ -3757,8 +3903,8 @@ class field(object):
np.std
"""
return self.domain.unary_operation(self.val, op='std', **kwargs)
return self._unary_helper(self.get_val(), op='std',
**kwargs)
def var(self,**kwargs):
"""
Returns the variance of the field values.
......@@ -3773,8 +3919,8 @@ class field(object):
np.var
"""
return self.domain.unary_operation(self.val, op='var', **kwargs)
return self._unary_helper(self.get_val(), op='var',
**kwargs)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -3801,11 +3947,11 @@ class field(object):
"""
if split == True:
return self.domain.unary_operation(self.val, op='argmin', **kwargs)
return self._unary_helper(self.get_val(), op='argmin',
**kwargs)
else:
return self.domain.unary_operation(self.val,
op='argmin_flat', **kwargs)
return self._unary_helper(self.get_val(), op='argmin_flat',
**kwargs)
def argmax(self,split=True,**kwargs):
"""
Returns the index of the maximum field value.
......@@ -3829,10 +3975,11 @@ class field(object):
"""
if split == True:
return self.domain.unary_operation(self.val, op='argmax', **kwargs)
return self._unary_helper(self.get_val(), op='argmax',
**kwargs)
else:
return self.domain.unary_operation(self.val,
op='argmax_flat', **kwargs)
return self._unary_helper(self.get_val(), op='argmax_flat',
**kwargs)
......@@ -3840,75 +3987,88 @@ class field(object):
def __pos__(self):
new_field = self.copy_empty()
new_field.val = self.domain.unary_operation(self.val, op='pos')
new_val = self._unary_helper(self.get_val(), op='pos')
new_field.set_val(new_val = new_val)
return new_field
def __neg__(self):
new_field = self.copy_empty()
new_field.val = self.domain.unary_operation(self.val, op='neg')
new_val = self._unary_helper(self.get_val(), op='neg')
new_field.set_val(new_val = new_val)
return new_field
def __abs__(self):
new_field = self.copy_empty()
new_field.val = self.domain.unary_operation(self.val, op='abs')
new_val = self._unary_helper(self.get_val(), op='abs')
new_field.set_val(new_val = new_val)
return new_field
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## TODO: Rework to new structure!
def __binary_helper__(self, other, op='None'):
try:
other_val = other.val
except(AttributeError):
other_val = other
new_val = self.domain.binary_operation(self.val, other_val, op=op, cast=0)
new_field = self.copy_empty()
new_field.val = new_val
return new_field
def __inplace_binary_helper__(self, other, op='None'):
def _binary_helper(self, other, op='None', inplace = False):
try:
other_val = other.val
other_val = other.get_val()
except(AttributeError):
other_val = other
self.val = self.domain.binary_operation(self.val, other_val, op=op,
cast=0)
return self
## bring other_val into the right shape
if self.idim == 0:
other_val = self._cast_to_scalar_helper(other_val)
else:
other_val = self._cast_to_vector_helper(other_val)
new_val = self._map(
lambda z1, z2: self.domain.binary_operation(z1, z2, op=op, cast=0),
self.get_val(),
other_val)
if inplace == True:
working_field = self
else:
working_field = self.copy_empty()
working_field.set_val(new_val = new_val)
return working_field
def __add__(self, other):
return self.__binary_helper__(other, op='add')
return self._binary_helper(other, op='add')
__radd__ = __add__
def __iadd__(self, other):
return self.__inplace_binary_helper__(other, op='iadd')
return self._binary_helper(other, op='iadd', inplace = True)
def __sub__(self, other):
return self.__binary_helper__(other, op='sub')
return self._binary_helper(other, op='sub')
def __rsub__(self, other):
return self.__binary_helper__(other, op='rsub')
return self._binary_helper(other, op='rsub')
def __isub__(self, other):
return self.__inplace_binary_helper__(other, op='isub')
return self._binary_helper(other, op='isub', inplace = True)
def __mul__(self, other):
return self.__binary_helper__(other, op='mul')
return self._binary_helper(other, op='mul')
__rmul__ = __mul__
def __imul__(self, other):
return self.__inplace_binary_helper__(other, op='imul')
return self._binary_helper(other, op='imul', inplace = True)
def __div__(self, other):
return self.__binary_helper__(other, op='div')
return self._binary_helper(other, op='div')
def __rdiv__(self, other):
return self.__binary_helper__(other, op='rdiv')
return self._binary_helper(other, op='rdiv')
def __idiv__(self, other):
return self.__inplace_binary_helper__(other, op='idiv')
return self._binary_helper(other, op='idiv', inplace = True)
__truediv__ = __div__
__itruediv__ = __idiv__
def __pow__(self, other):
return self.__binary_helper__(other, op='pow')