Commit 723089ab authored by Martin Reinecke's avatar Martin Reinecke
Browse files

improve checks when drawing from DiagonalOperator

parent 91944cd3
Pipeline #28885 passed with stages
in 10 minutes and 23 seconds
......@@ -322,6 +322,12 @@ def np_allreduce_sum(arr):
return res
def np_allreduce_min(arr):
res = np.empty_like(arr)
_comm.Allreduce(arr, res, MPI.MIN)
return res
def distaxis(arr):
return arr._distaxis
......
......@@ -70,6 +70,10 @@ def np_allreduce_sum(arr):
return arr
def np_allreduce_min(arr):
return arr
def distaxis(arr):
return -1
......
......@@ -92,45 +92,46 @@ class DiagonalOperator(EndomorphicOperator):
self._ldiag = self._ldiag.reshape(self._reshaper)
else:
self._ldiag = diagonal.local_data
self._update_diagmin()
def _update_diagmin(self):
self._ldiag.flags.writeable = False
if not np.issubdtype(self._ldiag.dtype, np.complexfloating):
lmin = self._ldiag.min() if self._ldiag.size > 0 else 1.
self._diagmin = dobj.np_allreduce_min(np.array(lmin))[()]
def _skeleton(self, spc):
def _from_ldiag(self, spc, ldiag):
res = DiagonalOperator.__new__(DiagonalOperator)
res._domain = self._domain
if self._spaces is None or spc is None:
res._spaces = None
else:
res._spaces = tuple(set(self._spaces) | set(spc))
res._ldiag = ldiag
res._update_diagmin()
return res
def _scale(self, fct):
if not np.isscalar(fct):
raise TypeError("scalar value required")
res = self._skeleton(())
res._ldiag = self._ldiag*fct
return res
return self._from_ldiag((), self._ldiag*fct)
def _add(self, sum):
if not np.isscalar(sum):
raise TypeError("scalar value required")
res = self._skeleton(())
res._ldiag = self._ldiag + sum
return res
return self._from_ldiag((), self._ldiag+sum)
def _combine_prod(self, op):
if not isinstance(op, DiagonalOperator):
raise TypeError("DiagonalOperator required")
res = self._skeleton(op._spaces)
res._ldiag = self._ldiag*op._ldiag
return res
return self._from_ldiag(op._spaces, self._ldiag*op._ldiag)
def _combine_sum(self, op, selfneg, opneg):
if not isinstance(op, DiagonalOperator):
raise TypeError("DiagonalOperator required")
res = self._skeleton(op._spaces)
res._ldiag = (self._ldiag * (-1 if selfneg else 1) +
op._ldiag * (-1 if opneg else 1))
return res
tdiag = (self._ldiag * (-1 if selfneg else 1) +
op._ldiag * (-1 if opneg else 1))
return self._from_ldiag(op._spaces, tdiag)
def apply(self, x, mode):
self._check_input(x, mode)
......@@ -166,20 +167,21 @@ class DiagonalOperator(EndomorphicOperator):
return self
if trafo == ADJ and np.issubdtype(self._ldiag.dtype, np.floating):
return self
res = self._skeleton(())
if trafo == ADJ:
res._ldiag = self._ldiag.conjugate()
return self._from_ldiag((), self._ldiag.conjugate())
elif trafo == INV:
res._ldiag = 1./self._ldiag
return self._from_ldiag((), 1./self._ldiag)
elif trafo == ADJ | INV:
res._ldiag = 1./self._ldiag.conjugate()
else:
raise ValueError("invalid operator transformation")
return res
return self._from_ldiag((), 1./self._ldiag.conjugate())
raise ValueError("invalid operator transformation")
def draw_sample(self, from_inverse=False, dtype=np.float64):
if np.issubdtype(self._ldiag.dtype, np.complexfloating):
raise ValueError("operator not positive definite")
if self._diagmin < 0.:
raise ValueError("operator not positive definite")
if self._diagmin == 0. and from_inverse:
raise ValueError("operator not positive definite")
res = Field.from_random(random_type="normal", domain=self._domain,
dtype=dtype)
if from_inverse:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment