Commit 55564314 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

experimental performance tweaks

parent 79df4e2b
......@@ -631,10 +631,98 @@ class Field(object):
return 0.5*(1.+self.tanh())
return Field(self._domain, 0.5*(1.+np.tanh(self._uni)))
def __add__(self, other):
# if other is a field, make sure that the domains match
if isinstance(other, Field):
if other._domain is not self._domain:
raise ValueError("domains are incompatible.")
if self._uni is None:
if other._uni is None:
return Field(self._domain, self._val+other._val)
if other._uni == 0:
return self
return Field(self._domain, self._val+other._uni)
else:
if self._uni == 0:
return other
if other._uni is None:
return Field(self._domain, other._val+self._uni)
return Field(self._domain, self._uni+other._uni)
if np.isscalar(other):
if self._uni is None:
return Field(self._domain, self._val+other)
return Field(self._domain, self._uni+other)
if isinstance(other, (dobj.data_object, np.ndarray)):
return Field(self._domain, self._val+other)
return NotImplemented
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
# if other is a field, make sure that the domains match
if isinstance(other, Field):
if other._domain is not self._domain:
raise ValueError("domains are incompatible.")
if self._uni is None:
if other._uni is None:
return Field(self._domain, self._val-other._val)
if other._uni == 0:
return self
return Field(self._domain, self._val-other._uni)
else:
if self._uni == 0:
return -other
if other._uni is None:
return Field(self._domain, self._uni-other._val)
return Field(self._domain, self._uni-other._uni)
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
"__mul__", "__rmul__",
if np.isscalar(other):
if self._uni is None:
return Field(self._domain, self._val-other)
return Field(self._domain, self._uni-other)
if isinstance(other, (dobj.data_object, np.ndarray)):
return Field(self._domain, self._val-other)
return NotImplemented
def __mul__(self, other):
# if other is a field, make sure that the domains match
if isinstance(other, Field):
if other._domain is not self._domain:
raise ValueError("domains are incompatible.")
if self._uni is None:
if other._uni is None:
return Field(self._domain, self._val*other._val)
if other._uni == 1:
return self
if other._uni == 0:
return other
return Field(self._domain, self._val*other._uni)
else:
if self._uni == 1:
return other
if self._uni == 0:
return self
if other._uni is None:
return Field(self._domain, other._val*self._uni)
return Field(self._domain, self._uni*other._uni)
if np.isscalar(other):
if self._uni is None:
if other == 1:
return self
if other == 0:
return Field(self._domain, other)
return Field(self._domain, self._val*other)
return Field(self._domain, self._uni*other)
if isinstance(other, (dobj.data_object, np.ndarray)):
return Field(self._domain, self._val*other)
return NotImplemented
for op in ["__rsub__",
"__rmul__",
"__div__", "__rdiv__",
"__truediv__", "__rtruediv__",
"__floordiv__", "__rfloordiv__",
......
......@@ -94,11 +94,12 @@ class DiagonalOperator(EndomorphicOperator):
self._ldiag = self._ldiag.reshape(self._reshaper)
else:
self._ldiag = diagonal.local_data
self._update_diagmin()
self._fill_rest()
def _update_diagmin(self):
def _fill_rest(self):
self._ldiag.flags.writeable = False
if not np.issubdtype(self._ldiag.dtype, np.complexfloating):
self._complex = np.issubdtype(self._ldiag.dtype, np.complexfloating)
if not self._complex:
lmin = self._ldiag.min() if self._ldiag.size > 0 else 1.
self._diagmin = dobj.np_allreduce_min(np.array(lmin))[()]
......@@ -110,7 +111,7 @@ class DiagonalOperator(EndomorphicOperator):
else:
res._spaces = tuple(set(self._spaces) | set(spc))
res._ldiag = ldiag
res._update_diagmin()
res._fill_rest()
return res
def _scale(self, fct):
......@@ -137,21 +138,17 @@ class DiagonalOperator(EndomorphicOperator):
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
# shortcut for most common cases
if mode == 1 or (not self._complex and mode == 2):
return Field(x.domain, val=x.val*self._ldiag)
elif mode == self.ADJOINT_TIMES:
if np.issubdtype(self._ldiag.dtype, np.floating):
return Field(x.domain, val=x.val*self._ldiag)
else:
return Field(x.domain, val=x.val*self._ldiag.conj())
elif mode == self.INVERSE_TIMES:
return Field(x.domain, val=x.val/self._ldiag)
else:
if np.issubdtype(self._ldiag.dtype, np.floating):
return Field(x.domain, val=x.val/self._ldiag)
else:
return Field(x.domain, val=x.val/self._ldiag.conj())
xdiag = self._ldiag
if self._complex and (mode & 10): # adjoint or inverse adjoint
xdiag = xdiag.conj()
if mode & 3:
return Field(x.domain, val=x.val*xdiag)
return Field(x.domain, val=x.val/xdiag)
@property
def domain(self):
......@@ -162,23 +159,15 @@ class DiagonalOperator(EndomorphicOperator):
return self._all_ops
def _flip_modes(self, trafo):
ADJ = self.ADJOINT_BIT
INV = self.INVERSE_BIT
if trafo == 0:
return self
if trafo == ADJ and np.issubdtype(self._ldiag.dtype, np.floating):
return self
if trafo == ADJ:
return self._from_ldiag((), self._ldiag.conjugate())
elif trafo == INV:
return self._from_ldiag((), 1./self._ldiag)
elif trafo == ADJ | INV:
return self._from_ldiag((), 1./self._ldiag.conjugate())
raise ValueError("invalid operator transformation")
xdiag = self._ldiag
if self._complex and (trafo & self.ADJOINT_BIT):
xdiag = xdiag.conj()
if trafo & self.INVERSE_BIT:
xdiag = 1./xdiag
return self._from_ldiag((), xdiag)
def draw_sample(self, from_inverse=False, dtype=np.float64):
if np.issubdtype(self._ldiag.dtype, np.complexfloating):
if self._complex:
raise ValueError("operator not positive definite")
if self._diagmin < 0.:
raise ValueError("operator not positive definite")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment