Commit 23e0643e authored by Ultima's avatar Ultima
Browse files

Improved the multidimensionality of fields; a lot.

Added some ufuncs to point_space, field and distributed_data_object.
Updated the projection_operator.
parent 0f1e4c12
......@@ -170,7 +170,14 @@ class lm_space(point_space):
about.warnings.cprint("WARNING: data type set to default.")
datatype = np.complex128
self.datatype = datatype
## set datamodel
if datamodel not in ['np']:
about.warnings.cprint("WARNING: datamodel set to default.")
self.datamodel = 'np'
else:
self.datamodel = datamodel
self.discrete = True
self.vol = np.real(np.array([1],dtype=self.datatype))
......@@ -1008,7 +1015,7 @@ class gl_space(point_space):
vol : numpy.ndarray
An array containing the pixel sizes.
"""
def __init__(self,nlat,nlon=None,datatype=None):
def __init__(self, nlat, nlon=None, datatype=None, datamodel='np'):
"""
Sets the attributes for a gl_space class instance.
......@@ -1047,6 +1054,13 @@ class gl_space(point_space):
about.warnings.cprint("WARNING: data type set to default.")
datatype = np.float64
self.datatype = datatype
## set datamodel
if datamodel not in ['np']:
about.warnings.cprint("WARNING: datamodel set to default.")
self.datamodel = 'np'
else:
self.datamodel = datamodel
self.discrete = False
self.vol = gl.vol(self.paradict['nlat'],nlon=self.paradict['nlon']).astype(self.datatype)
......@@ -1702,7 +1716,7 @@ class hp_space(point_space):
"""
niter = 0 ## default number of iterations used for transformations
def __init__(self, nside):
def __init__(self, nside, datamodel = 'np'):
"""
Sets the attributes for a hp_space class instance.
......@@ -1731,6 +1745,14 @@ class hp_space(point_space):
self.paradict = hp_space_paradict(nside=nside)
self.datatype = np.float64
## set datamodel
if datamodel not in ['np']:
about.warnings.cprint("WARNING: datamodel set to default.")
self.datamodel = 'np'
else:
self.datamodel = datamodel
self.discrete = False
self.vol = np.array([4*pi/(12*self.paradict['nside']**2)],dtype=self.datatype)
......
......@@ -143,6 +143,7 @@
from __future__ import division
import numpy as np
import pylab as pl
from nifty_paradict import space_paradict,\
point_space_paradict,\
nested_space_paradict
......@@ -904,7 +905,7 @@ class space(object):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def __len__(self):
return int(self.dim(split=False))
return int(self.get_dim(split=False))
## _identiftier returns an object which contains all information needed
## to uniquely idetnify a space. It returns a (immutable) tuple which therefore
......@@ -1088,7 +1089,8 @@ class point_space(space):
## check datatype
if (datatype is None):
datatype = np.float64
elif (datatype not in [np.int8,
elif (datatype not in [np.bool_,
np.int8,
np.int16,
np.int32,
np.int64,
......@@ -1202,6 +1204,8 @@ class point_space(space):
"conjugate" : np.conjugate,
"sum" : np.sum,
"prod" : np.prod,
"unique" : np.unique,
"copy" : np.copy,
"None" : lambda y: y}
elif self.datamodel == 'd2o':
......@@ -1223,6 +1227,8 @@ class point_space(space):
"conjugate" : lambda y: getattr(y, 'conjugate')(),
"sum" : lambda y: getattr(y, 'sum')(),
"prod" : lambda y: getattr(y, 'prod')(),
"unique" : lambda y: getattr(y, 'unique')(),
"copy" : lambda y: getattr(y, 'copy')(),
"None" : lambda y: y}
else:
raise NotImplementedError(about._errors.cstring(
......@@ -1248,6 +1254,12 @@ class point_space(space):
"pow" : lambda z: getattr(z, '__pow__'),
"rpow" : lambda z: getattr(z, '__rpow__'),
"ipow" : lambda z: getattr(z, '__ipow__'),
"ne" : lambda z: getattr(z, '__ne__'),
"lt" : lambda z: getattr(z, '__lt__'),
"le" : lambda z: getattr(z, '__le__'),
"eq" : lambda z: getattr(z, '__eq__'),
"ge" : lambda z: getattr(z, '__ge__'),
"gt" : lambda z: getattr(z, '__gt__'),
"None" : lambda z: lambda u: u}
if (cast & 1) != 0:
......@@ -1412,16 +1424,21 @@ class point_space(space):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def cast(self, x, verbose = False):
def cast(self, x, dtype = None, verbose = False, **kwargs):
if dtype is not None:
dtype = np.dtype(dtype).type
if self.datamodel == 'd2o':
return self._cast_to_d2o(x = x, verbose = False)
return self._cast_to_d2o(x = x, dtype = dtype, verbose = verbose,
**kwargs)
elif self.datamodel == 'np':
return self._cast_to_np(x = x, verbose = False)
return self._cast_to_np(x = x, dtype = dtype, verbose = verbose,
**kwargs)
else:
raise NotImplementedError(about._errors.cstring(
"ERROR: function is not implemented for given datamodel."))
def _cast_to_d2o(self, x, verbose=False):
def _cast_to_d2o(self, x, dtype=None, verbose=False, **kwargs):
"""
Computes valid field values from a given object, trying
to translate the given data into a valid form. Thereby it is as
......@@ -1444,6 +1461,9 @@ class point_space(space):
Whether the method should raise a warning if information is
lost during casting (default: False).
"""
if dtype is None:
dtype = self.datatype
## Case 1: x is a field
if isinstance(x, field):
if verbose:
......@@ -1452,14 +1472,14 @@ class point_space(space):
about.warnings.cflush(\
"WARNING: Getting data from foreign domain!")
## Extract the data, whatever it is, and cast it again
return self.cast(x.val)
return self.cast(x.val, dtype=dtype)
## Case 2: x is a distributed_data_object
if isinstance(x, distributed_data_object):
## Check the shape
if np.any(x.shape != self.get_shape()):
## Check if at least the number of degrees of freedom is equal
if x.dim() == self.get_dim():
if x.get_dim() == self.get_dim():
## If the number of dof is equal or 1, use np.reshape...
about.warnings.cflush(\
"WARNING: Trying to reshape the data. This operation is "+\
......@@ -1467,20 +1487,20 @@ class point_space(space):
temp = x.get_full_data()
temp = np.reshape(temp, self.get_shape())
## ... and cast again
return self.cast(temp)
return self.cast(temp, dtype=dtype)
else:
raise ValueError(about._errors.cstring(\
"ERROR: Data has incompatible shape!"))
## Check the datatype
if x.dtype < self.datatype:
if x.dtype != dtype:
about.warnings.cflush(\
"WARNING: Datatypes are uneqal/of conflicting precision (own: "\
+ str(self.datatype) + " <> foreign: " + str(x.dtype) \
+ str(dtype) + " <> foreign: " + str(x.dtype) \
+ ") and will be casted! "\
+ "Potential loss of precision!\n")
temp = x.copy_empty(dtype=self.datatype)
temp = x.copy_empty(dtype=dtype)
temp.set_local_data(x.get_local_data())
temp.hermitian = x.hermitian
x = temp
......@@ -1490,11 +1510,11 @@ class point_space(space):
## Case 3: x is something else
## Use general d2o casting
x = distributed_data_object(x, global_shape=self.get_shape(),\
dtype=self.datatype)
dtype=dtype)
## Cast the d2o
return self.cast(x)
return self.cast(x, dtype=dtype)
def _cast_to_np(self, x, verbose = False):
def _cast_to_np(self, x, dtype = None, verbose = False, **kwargs):
"""
Computes valid field values from a given object, trying
to translate the given data into a valid form. Thereby it is as
......@@ -1517,6 +1537,9 @@ class point_space(space):
Whether the method should raise a warning if information is
lost during casting (default: False).
"""
if dtype is None:
dtype = self.datatype
## Case 1: x is a field
if isinstance(x, field):
if verbose:
......@@ -1525,14 +1548,14 @@ class point_space(space):
about.warnings.cflush(\
"WARNING: Getting data from foreign domain!")
## Extract the data, whatever it is, and cast it again
return self.cast(x.val)
return self.cast(x.val, dtype=dtype)
## Case 2: x is a distributed_data_object
if isinstance(x, distributed_data_object):
## Extract the data
temp = x.get_full_data()
## Cast the resulting numpy array again
return self.cast(temp)
return self.cast(temp, dtype=dtype)
elif isinstance(x, np.ndarray):
## Check the shape
......@@ -1542,35 +1565,36 @@ class point_space(space):
## If the number of dof is equal or 1, use np.reshape...
temp = x.reshape(self.get_shape())
## ... and cast again
return self.cast(temp)
return self.cast(temp, dtype=dtype)
elif x.size == 1:
temp = np.empty(shape = self.get_shape(),
dtype = self.datatype)
dtype = dtype)
temp[:] = x
return self.cast(temp)
return self.cast(temp, dtype=dtype)
else:
raise ValueError(about._errors.cstring(\
"ERROR: Data has incompatible shape!"))
## Check the datatype
if x.dtype < self.datatype:
if x.dtype != dtype:
about.warnings.cflush(\
"WARNING: Datatypes are uneqal/of conflicting precision (own: "\
+ str(self.datatype) + " <> foreign: " + str(x.dtype) \
+ str(dtype) + " <> foreign: " + str(x.dtype) \
+ ") and will be casted! "\
+ "Potential loss of precision!\n")
## Fix the datatype...
temp = x.astype(self.datatype)
temp = x.astype(dtype)
##... and cast again
return self.cast(temp)
return self.cast(temp, dtype=dtype)
return x
## Case 3: x is something else
## Use general numpy casting
else:
temp = np.empty(self.get_shape(), dtype = self.datatype)
temp[:] = x
temp = np.empty(self.get_shape(), dtype = dtype)
if x is not None:
temp[:] = x
return temp
......@@ -1592,12 +1616,12 @@ class point_space(space):
about.warnings.cprint("WARNING: enforce_shape is deprecated!")
x = np.array(x)
if(np.size(x)!=self.dim(split=False)):
raise ValueError(about._errors.cstring("ERROR: dimension mismatch ( "+str(np.size(x))+" <> "+str(self.dim(split=False))+" )."))
# elif(not np.all(np.array(np.shape(x))==self.dim(split=True))):
if(np.size(x)!=self.get_dim(split=False)):
raise ValueError(about._errors.cstring("ERROR: dimension mismatch ( "+str(np.size(x))+" <> "+str(self.get_dim(split=False))+" )."))
# elif(not np.all(np.array(np.shape(x))==self.get_dim(split=True))):
# about.warnings.cprint("WARNING: reshaping forced.")
return x.reshape(self.dim(split=True),order='C')
return x.reshape(self.get_dim(split=True),order='C')
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -1636,7 +1660,7 @@ class point_space(space):
else:
if(np.size(x)==1):
if(extend):
x = self.datatype(x)*np.ones(self.dim(split=True),dtype=self.datatype,order='C')
x = self.datatype(x)*np.ones(self.get_dim(split=True),dtype=self.datatype,order='C')
else:
if(np.isscalar(x)):
x = np.array([x],dtype=self.datatype)
......@@ -2011,14 +2035,22 @@ class point_space(space):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def calc_real_Q(self, x):
try:
return x.isreal().all()
except(AttributeError):
return np.all(np.isreal(x))
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def calc_bincount(self, x, weights=None, minlength=None):
if self.datamodel == 'np':
return np.bincount(x, weights=weights, minlength=minlength)
elif self.datamodel == 'd2o':
return x.bincount(weights=weights, minlength=minlength)
else:
raise NotImplementedError(about._errors.cstring(
"ERROR: function is not implemented for given datamodel."))
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def get_plot(self,x,title="",vmin=None,vmax=None,unit="",norm=None,other=None,legend=False,**kwargs):
......@@ -2187,10 +2219,10 @@ class nested_space(space):
elif(isinstance(nn,nested_space)): ## no 2nd level nesting
for nn_ in nn.nest:
purenest.append(nn_)
pre_para = pre_para + [nn_.dim(split=True)]
pre_para = pre_para + [nn_.get_dim(split=True)]
else:
purenest.append(nn)
pre_para = pre_para + [nn.dim(split=True)]
pre_para = pre_para + [nn.get_dim(split=True)]
if(len(purenest)<2):
raise ValueError(about._errors.cstring("ERROR: invalid input."))
self.nest = purenest
......@@ -2333,14 +2365,14 @@ class nested_space(space):
if(self.datatype is not x.domain.datatype):
raise TypeError(about._errors.cstring("ERROR: inequal data types ( '"+str(np.result_type(self.datatype))+"' <> '"+str(np.result_type(x.domain.datatype))+"' )."))
else:
subshape = self.para[:-np.size(self.nest[-1].dim(split=True))]
subshape = self.para[:-np.size(self.nest[-1].get_dim(split=True))]
x = np.tensordot(np.ones(subshape,dtype=self.datatype,order='C'),x.val,axes=0)
elif(isinstance(x.domain,nested_space)):
if(self.datatype is not x.domain.datatype):
raise TypeError(about._errors.cstring("ERROR: inequal data types ( '"+str(np.result_type(self.datatype))+"' <> '"+str(np.result_type(x.domain.datatype))+"' )."))
else:
if(np.all(self.nest[-len(x.domain.nest):]==x.domain.nest)):
subshape = self.para[:np.sum([np.size(nn.dim(split=True)) for nn in self.nest[:-len(x.domain.nest)]],axis=0,dtype=np.int,out=None)]
subshape = self.para[:np.sum([np.size(nn.get_dim(split=True)) for nn in self.nest[:-len(x.domain.nest)]],axis=0,dtype=np.int,out=None)]
x = np.tensordot(np.ones(subshape,dtype=self.datatype,order='C'),x.val,axes=0)
else:
raise ValueError(about._errors.cstring("ERROR: inequal domains."))
......@@ -2360,18 +2392,18 @@ class nested_space(space):
if(np.ndim(x)<np.size(self.para)):
subshape = np.array([],dtype=np.int)
for ii in range(len(self.nest))[::-1]:
subshape = np.append(self.nest[ii].dim(split=True),subshape,axis=None)
subshape = np.append(self.nest[ii].get_dim(split=True),subshape,axis=None)
if(np.all(np.array(np.shape(x))==subshape)):
subshape = self.para[:np.sum([np.size(nn.dim(split=True)) for nn in self.nest[:ii]],axis=0,dtype=np.int,out=None)]
subshape = self.para[:np.sum([np.size(nn.get_dim(split=True)) for nn in self.nest[:ii]],axis=0,dtype=np.int,out=None)]
x = np.tensordot(np.ones(subshape,dtype=self.datatype,order='C'),x,axes=0)
break
else:
x = self.enforce_shape(x)
if(np.size(x)!=1):
subdim = np.prod(self.para[:-np.size(self.nest[-1].dim(split=True))],axis=0,dtype=np.int,out=None)
subdim = np.prod(self.para[:-np.size(self.nest[-1].get_dim(split=True))],axis=0,dtype=np.int,out=None)
## enforce special properties
x = x.reshape([subdim]+self.nest[-1].dim(split=True).tolist(),order='C')
x = x.reshape([subdim]+self.nest[-1].get_dim(split=True).tolist(),order='C')
x = np.array([self.nest[-1].enforce_values(xx,extend=True) for xx in x],dtype=self.datatype).reshape(self.para,order='C')
## check finiteness
......@@ -2418,23 +2450,23 @@ class nested_space(space):
arg = random.parse_arguments(self,**kwargs)
if(arg is None):
return np.zeros(self.dim(split=True),dtype=self.datatype,order='C')
return np.zeros(self.get_dim(split=True),dtype=self.datatype,order='C')
elif(arg[0]=="pm1"):
x = random.pm1(datatype=self.datatype,shape=self.dim(split=True))
x = random.pm1(datatype=self.datatype,shape=self.get_dim(split=True))
elif(arg[0]=="gau"):
x = random.gau(datatype=self.datatype,shape=self.dim(split=True),mean=None,dev=arg[2],var=arg[3])
x = random.gau(datatype=self.datatype,shape=self.get_dim(split=True),mean=None,dev=arg[2],var=arg[3])
elif(arg[0]=="uni"):
x = random.uni(datatype=self.datatype,shape=self.dim(split=True),vmin=arg[1],vmax=arg[2])
x = random.uni(datatype=self.datatype,shape=self.get_dim(split=True),vmin=arg[1],vmax=arg[2])
else:
raise KeyError(about._errors.cstring("ERROR: unsupported random key '"+str(arg[0])+"'."))
subdim = np.prod(self.para[:-np.size(self.nest[-1].dim(split=True))],axis=0,dtype=np.int,out=None)
subdim = np.prod(self.para[:-np.size(self.nest[-1].get_dim(split=True))],axis=0,dtype=np.int,out=None)
## enforce special properties
x = x.reshape([subdim]+self.nest[-1].dim(split=True).tolist(),order='C')
x = x.reshape([subdim]+self.nest[-1].get_dim(split=True).tolist(),order='C')
x = np.array([self.nest[-1].enforce_values(xx,extend=True) for xx in x],dtype=self.datatype).reshape(self.para,order='C')
return x
......@@ -2663,7 +2695,7 @@ class nested_space(space):
## analyse (sub)array
dotspace = None
subspace = None
if(np.size(y)==1)or(np.all(np.array(np.shape(y))==self.nest[-1].dim(split=True))):
if(np.size(y)==1)or(np.all(np.array(np.shape(y))==self.nest[-1].get_dim(split=True))):
dotspace = self.nest[-1]
if(len(self.nest)==2):
subspace = self.nest[0]
......@@ -2673,9 +2705,9 @@ class nested_space(space):
about.warnings.cprint("WARNING: computing (normal) inner product.")
return self.calc_dot(x,self.enforce_values(y,extend=True))
else:
dotshape = self.nest[-1].dim(split=True)
dotshape = self.nest[-1].get_dim(split=True)
for ii in range(len(self.nest)-1)[::-1]:
dotshape = np.append(self.nest[ii].dim(split=True),dotshape,axis=None)
dotshape = np.append(self.nest[ii].get_dim(split=True),dotshape,axis=None)
if(np.all(np.array(np.shape(y))==dotshape)):
dotspace = nested_space(self.nest[ii:])
if(ii<2):
......@@ -2692,8 +2724,8 @@ class nested_space(space):
if(not dotspace.discrete):
y = dotspace.calc_weight(y,power=1)
## pseudo inner product(s)
x = x.reshape([subspace.dim(split=False)]+dotspace.dim(split=True).tolist(),order='C')
pot = np.array([dotspace.calc_dot(xx,y) for xx in x],dtype=subspace.datatype).reshape(subspace.dim(split=True),order='C')
x = x.reshape([subspace.get_dim(split=False)]+dotspace.get_dim(split=True).tolist(),order='C')
pot = np.array([dotspace.calc_dot(xx,y) for xx in x],dtype=subspace.datatype).reshape(subspace.get_dim(split=True),order='C')
return field(subspace,val=pot,**kwargs)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -2739,10 +2771,10 @@ class nested_space(space):
elif(isinstance(codomain,nested_space)):
if(np.all(codomain.nest[:-1]==self.nest[:-1]))and(coorder is None):
## reshape
subdim = np.prod(self.para[:-np.size(self.nest[-1].dim(split=True))],axis=0,dtype=np.int,out=None)
x = x.reshape([subdim]+self.nest[-1].dim(split=True).tolist(),order='C')
subdim = np.prod(self.para[:-np.size(self.nest[-1].get_dim(split=True))],axis=0,dtype=np.int,out=None)
x = x.reshape([subdim]+self.nest[-1].get_dim(split=True).tolist(),order='C')
## transform
Tx = np.array([self.nest[-1].calc_transform(xx,codomain=codomain.nest[-1],**kwargs) for xx in x],dtype=codomain.datatype).reshape(codomain.dim(split=True),order='C')
Tx = np.array([self.nest[-1].calc_transform(xx,codomain=codomain.nest[-1],**kwargs) for xx in x],dtype=codomain.datatype).reshape(codomain.get_dim(split=True),order='C')
elif(len(codomain.nest)==len(self.nest)):#and(np.all([nn in self.nest for nn in codomain.nest]))and(np.all([nn in codomain.nest for nn in self.nest])):
## check coorder
if(coorder is None):
......@@ -2765,7 +2797,7 @@ class nested_space(space):
## compute axes permutation
lim = np.zeros((len(self.nest),2),dtype=np.int)
for ii in xrange(len(self.nest)):
lim[ii] = np.array([lim[ii-1][1],lim[ii-1][1]+np.size(self.nest[coorder[ii]].dim(split=True))])
lim[ii] = np.array([lim[ii-1][1],lim[ii-1][1]+np.size(self.nest[coorder[ii]].get_dim(split=True))])
lim = lim[coorder]
reorder = []
for ii in xrange(len(self.nest)):
......@@ -2823,10 +2855,10 @@ class nested_space(space):
return x
else:
## reshape
subdim = np.prod(self.para[:-np.size(self.nest[-1].dim(split=True))],axis=0,dtype=np.int,out=None)
x = x.reshape([subdim]+self.nest[-1].dim(split=True).tolist(),order='C')
subdim = np.prod(self.para[:-np.size(self.nest[-1].get_dim(split=True))],axis=0,dtype=np.int,out=None)
x = x.reshape([subdim]+self.nest[-1].get_dim(split=True).tolist(),order='C')
## smooth
return np.array([self.nest[-1].calc_smooth(xx,sigma=sigma,**kwargs) for xx in x],dtype=self.datatype).reshape(self.dim(split=True),order='C')
return np.array([self.nest[-1].calc_smooth(xx,sigma=sigma,**kwargs) for xx in x],dtype=self.datatype).reshape(self.get_dim(split=True),order='C')
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
......@@ -2938,7 +2970,7 @@ class field(object):
The space wherein the operator output lives (default: domain).
"""
def __init__(self, domain, val=None, codomain=None, idim=0, **kwargs):
def __init__(self, domain, val=None, codomain=None, ishape=None, **kwargs):
"""
Sets the attributes for a field class instance.
......@@ -2973,9 +3005,26 @@ class field(object):
assert(self.domain.check_codomain(codomain))
self.codomain = codomain
self.idim = np.uint(idim)
if ishape is not None:
ishape = tuple(np.array(ishape, dtype = np.uint).flatten())
elif val is not None:
try:
if val.dtype.type == np.object_:
ishape = val.shape
else:
ishape = ()
except(AttributeError):
try:
ishape = val.ishape
except(AttributeError):
ishape = ()
else:
ishape = ()
self.ishape = ishape
if val == None:
if val is None:
if kwargs == {}:
val = self._map(lambda: self.domain.cast(0.))
else:
......@@ -2997,110 +3046,176 @@ class field(object):
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def _map(self, function, *args):
if self.idim == 0:
if self.ishape == ():
return function(*args)
else:
if args == ():
result = []
for i in xrange(self.idim):
result.append(function())
result = np.empty(self.ishape, dtype=np.object)
for i in xrange(np.prod(self.ishape)):
ii = np.unravel_index(i, self.ishape)
result[ii] = function()
return result
else:
return map(function, *args)
## define a helper function in order to clip the get-indices
## to be suitable for the foreign arrays in args.
## This allows you to do operations, like adding to fields
## with ishape (3,4,3) and (3,4,1)
def get_clipped(w, ind):
w_shape = np.array(np.shape(w))
get_tuple = tuple(np.clip(ind, 0, w_shape-1))
return w[get_tuple]
result = np.empty_like(args[0])
for i in xrange(np.prod(result.shape)):
ii = np.unravel_index(i, result.shape)
result[ii] = function(*map(
lambda z: get_clipped(z, ii), args)
)
#result[ii] = function(*map(lambda z: z[ii], args))
return result
def cast(self, x = None):
if self.idim == 0:
scalarized_x = self._cast_to_scalar_helper(x)
return self.domain.cast(scalarized_x)
def cast(self, x = None, ishape = None):
if ishape is None:
ishape = self.ishape
casted_x = self._cast_to_ishape(x, ishape=ishape)
if ishape == ():
return self.domain.cast(casted_x)
else:
vectorized_x = self._cast_to_vector_helper(x)
return self._map(lambda z: self.domain.cast(z),
vectorized_x)
casted_x)
def _cast_to_ishape(self, x, ishape = None):
if ishape is None:
ishape = self.ishape
if isinstance(x, field):
x = x.get_val()
if ishape == ():
casted_x = self._cast_to_scalar_helper(x)
else:
casted_x = self._cast_to_tensor_helper(x, ishape)
return casted_x
def _cast_to_scalar_helper(self, x):
## If x is a list, take the first entry and cast it with self.domain
if isinstance(x, list):
if len(x) >= 1:
if len(x) > 1:
about.warnings.cprint(
## if x is already a scalar or does fit directly, return it