Commit e167ce2c authored by csongor's avatar csongor

solve some casting issues

parent e30f7bc2
Pipeline #4994 skipped
......@@ -277,14 +277,6 @@ class space(object):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'apply_scalar_function'."))
def unary_operation(self, x, op=None):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'unary_operation'."))
def binary_operation(self, x, y, op=None):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'binary_operation'."))
def get_shape(self):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'shape'."))
......@@ -836,79 +828,6 @@ class point_space(space):
def apply_scalar_function(self, x, function, inplace=False):
return x.apply_scalar_function(function, inplace=inplace)
def unary_operation(self, x, op='None', axis=None, **kwargs):
"""
x must be a numpy array which is compatible with the space!
Valid operations are
"""
translation = {'pos': lambda y: getattr(y, '__pos__')(),
'neg': lambda y: getattr(y, '__neg__')(),
'abs': lambda y: getattr(y, '__abs__')(),
'real': lambda y: getattr(y, 'real'),
'imag': lambda y: getattr(y, 'imag'),
'nanmin': lambda y: getattr(y, 'nanmin')(axis=axis),
'amin': lambda y: getattr(y, 'amin')(axis=axis),
'nanmax': lambda y: getattr(y, 'nanmax')(axis=axis),
'amax': lambda y: getattr(y, 'amax')(axis=axis),
'median': lambda y: getattr(y, 'median')(axis=axis),
'mean': lambda y: getattr(y, 'mean')(axis=axis),
'std': lambda y: getattr(y, 'std')(axis=axis),
'var': lambda y: getattr(y, 'var')(axis=axis),
'argmin_nonflat': lambda y: getattr(y, 'argmin_nonflat')(
axis=axis),
'argmin': lambda y: getattr(y, 'argmin')(axis=axis),
'argmax_nonflat': lambda y: getattr(y, 'argmax_nonflat')(
axis=axis),
'argmax': lambda y: getattr(y, 'argmax')(axis=axis),
'conjugate': lambda y: getattr(y, 'conjugate')(),
'sum': lambda y: getattr(y, 'sum')(axis=axis),
'prod': lambda y: getattr(y, 'prod')(axis=axis),
'unique': lambda y: getattr(y, 'unique')(),
'copy': lambda y: getattr(y, 'copy')(),
'copy_empty': lambda y: getattr(y, 'copy_empty')(),
'isnan': lambda y: getattr(y, 'isnan')(),
'isinf': lambda y: getattr(y, 'isinf')(),
'isfinite': lambda y: getattr(y, 'isfinite')(),
'nan_to_num': lambda y: getattr(y, 'nan_to_num')(),
'all': lambda y: getattr(y, 'all')(axis=axis),
'any': lambda y: getattr(y, 'any')(axis=axis),
'None': lambda y: y}
return translation[op](x, **kwargs)
def binary_operation(self, x, y, op='None', cast=0):
translation = {'add': lambda z: getattr(z, '__add__'),
'radd': lambda z: getattr(z, '__radd__'),
'iadd': lambda z: getattr(z, '__iadd__'),
'sub': lambda z: getattr(z, '__sub__'),
'rsub': lambda z: getattr(z, '__rsub__'),
'isub': lambda z: getattr(z, '__isub__'),
'mul': lambda z: getattr(z, '__mul__'),
'rmul': lambda z: getattr(z, '__rmul__'),
'imul': lambda z: getattr(z, '__imul__'),
'div': lambda z: getattr(z, '__div__'),
'rdiv': lambda z: getattr(z, '__rdiv__'),
'idiv': lambda z: getattr(z, '__idiv__'),
'pow': lambda z: getattr(z, '__pow__'),
'rpow': lambda z: getattr(z, '__rpow__'),
'ipow': lambda z: getattr(z, '__ipow__'),
'ne': lambda z: getattr(z, '__ne__'),
'lt': lambda z: getattr(z, '__lt__'),
'le': lambda z: getattr(z, '__le__'),
'eq': lambda z: getattr(z, '__eq__'),
'ge': lambda z: getattr(z, '__ge__'),
'gt': lambda z: getattr(z, '__gt__'),
'None': lambda z: lambda u: u}
if (cast & 1) != 0:
x = self.cast(x)
if (cast & 2) != 0:
y = self.cast(y)
return translation[op](x)(y)
def get_shape(self):
return (self.paradict['num'],)
......
......@@ -212,9 +212,9 @@ class field(object):
if val is None:
if kwargs == {}:
val = map(lambda z: self.cast(z), (0,))
val = self.cast(0)
else:
val = map(lambda z: self.domain.get_random_values(
val = map(lambda z: self.get_random_values(domain = self.domain,
codomain=z, **kwargs), self.codomain)
self.set_val(new_val=val, copy=copy)
......@@ -281,6 +281,11 @@ class field(object):
self.codomain = codomain
return codomain
def get_random_values(self, **kwargs):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'enforce_power'."))
def __len__(self):
return int(self.get_dim()[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment