Commit 369c6e7c authored by Martin Reinecke's avatar Martin Reinecke
Browse files

be more restrictive with binary Field operations

parent 96ea43f4
......@@ -204,9 +204,6 @@ class data_object(object):
elif np.isscalar(other):
a = a._data
b = other
elif isinstance(other, np.ndarray):
a = a._data
b = other
else:
return NotImplemented
......
......@@ -654,8 +654,6 @@ class Field(object):
if self._uni is None:
return Field(self._domain, self._val+other)
return Field(self._domain, self._uni+other)
if isinstance(other, (dobj.data_object, np.ndarray)):
return Field(self._domain, self._val+other)
return NotImplemented
def __radd__(self, other):
......@@ -683,8 +681,6 @@ class Field(object):
if self._uni is None:
return Field(self._domain, self._val-other)
return Field(self._domain, self._uni-other)
if isinstance(other, (dobj.data_object, np.ndarray)):
return Field(self._domain, self._val-other)
return NotImplemented
def __mul__(self, other):
......@@ -717,8 +713,6 @@ class Field(object):
return Field(self._domain, other)
return Field(self._domain, self._val*other)
return Field(self._domain, self._uni*other)
if isinstance(other, (dobj.data_object, np.ndarray)):
return Field(self._domain, self._val*other)
return NotImplemented
......@@ -737,12 +731,9 @@ for op in ["__rsub__",
raise ValueError("domains are incompatible.")
tval = getattr(self._val, op)(other._val)
return Field(self._domain, tval)
if (np.isscalar(other) or
isinstance(other, (dobj.data_object, np.ndarray))):
if np.isscalar(other):
tval = getattr(self._val, op)(other)
return Field(self._domain, tval)
return NotImplemented
return func2
setattr(Field, op, func(op))
......
......@@ -140,15 +140,15 @@ class DiagonalOperator(EndomorphicOperator):
self._check_input(x, mode)
# shortcut for most common cases
if mode == 1 or (not self._complex and mode == 2):
return Field(x.domain, val=x.val*self._ldiag)
return Field.from_local_data(x.domain, x.local_data*self._ldiag)
xdiag = self._ldiag
if self._complex and (mode & 10): # adjoint or inverse adjoint
xdiag = xdiag.conj()
if mode & 3:
return Field(x.domain, val=x.val*xdiag)
return Field(x.domain, val=x.val/xdiag)
return Field.from_local_data(x.domain, x.local_data*xdiag)
return Field.from_local_data(x.domain, x.local_data/xdiag)
@property
def domain(self):
......@@ -176,6 +176,7 @@ class DiagonalOperator(EndomorphicOperator):
res = Field.from_random(random_type="normal", domain=self._domain,
dtype=dtype)
if from_inverse:
return res/np.sqrt(self._ldiag)
res = res.local_data/np.sqrt(self._ldiag)
else:
return res*np.sqrt(self._ldiag)
res = res.local_data*np.sqrt(self._ldiag)
return Field.from_local_data(self._domain, res)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment