Skip to content
Snippets Groups Projects
Commit e9a5b0f1 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

make inverse_draw_sample() largely obsolete

parent ad129166
No related branches found
No related tags found
1 merge request!237Replace InverseOperator and AdjointOperator with OperatorAdapter, and more
Pipeline #
......@@ -174,21 +174,14 @@ class DiagonalOperator(EndomorphicOperator):
raise ValueError("bad operator flipping mode")
return res
def draw_sample(self, dtype=np.float64):
def draw_sample(self, from_inverse=False, dtype=np.float64):
if (np.issubdtype(self._ldiag.dtype, np.complexfloating) or
(self._ldiag <= 0.).any()):
raise ValueError("operator not positive definite")
res = Field.from_random(random_type="normal", domain=self._domain,
dtype=dtype)
res.local_data[()] *= np.sqrt(self._ldiag)
return res
def inverse_draw_sample(self, dtype=np.float64):
if (np.issubdtype(self._ldiag.dtype, np.complexfloating) or
(self._ldiag <= 0.).any()):
raise ValueError("operator not positive definite")
res = Field.from_random(random_type="normal", domain=self._domain,
dtype=dtype)
res.local_data[()] /= np.sqrt(self._ldiag)
if from_inverse:
res.local_data[()] /= np.sqrt(self._ldiag)
else:
res.local_data[()] *= np.sqrt(self._ldiag)
return res
......@@ -36,12 +36,19 @@ class EndomorphicOperator(LinearOperator):
for endomorphic operators."""
return self.domain
def draw_sample(self, dtype=np.float64):
def draw_sample(self, from_inverse=False, dtype=np.float64):
"""Generate a zero-mean sample
Generates a sample from a Gaussian distribution with zero mean and
covariance given by the operator.
Parameters
----------
from_inverse : bool (default : False)
if True, the sample is drawn from the inverse of the operator
dtype : numpy datatype (default : numpy.float64)
the data type to be used for the sample
Returns
-------
Field
......@@ -59,8 +66,4 @@ class EndomorphicOperator(LinearOperator):
-------
A sample from the Gaussian of given covariance
"""
if self.capability & self.INVERSE_TIMES:
x = self.draw_sample(dtype)
return self.inverse_times(x)
else:
raise NotImplementedError
return self.draw_sample(True, dtype)
......@@ -20,11 +20,11 @@ from ..minimization.quadratic_energy import QuadraticEnergy
from ..minimization.iteration_controller import IterationController
from ..field import Field
from ..logger import logger
from .linear_operator import LinearOperator
from .endomorphic_operator import EndomorphicOperator
import numpy as np
class InversionEnabler(LinearOperator):
class InversionEnabler(EndomorphicOperator):
"""Class which augments the capability of another operator object via
numerical inversion.
......@@ -80,14 +80,9 @@ class InversionEnabler(LinearOperator):
logger.warning("Error detected during operator inversion")
return r.position
def draw_sample(self, dtype=np.float64):
def draw_sample(self, from_inverse=False, dtype=np.float64):
try:
return self._op.draw_sample(dtype)
return self._op.draw_sample(from_inverse, dtype)
except:
return self(self._op.inverse_draw_sample(dtype))
def inverse_draw_sample(self, dtype=np.float64):
try:
return self._op.inverse_draw_sample(dtype)
except:
return self.inverse_times(self._op.draw_sample(dtype))
samp = self._op.draw_sample(not from_inverse, dtype)
return self.inverse_times(samp) if from_inverse else self(samp)
......@@ -49,12 +49,7 @@ class OperatorAdapter(LinearOperator):
def apply(self, x, mode):
return self._op.apply(x, self._modeTable[self._mode][self._ilog[mode]])
def draw_sample(self, dtype=np.float64):
def draw_sample(self, from_inverse=False, dtype=np.float64):
if self._mode & self.INVERSE_BIT:
return self._op.inverse_draw_sample(dtype)
return self._op.draw_sample(dtype)
def inverse_draw_sample(self, dtype=np.float64):
if self._mode & self.INVERSE_BIT:
return self._op.draw_sample(dtype)
return self._op.inverse_draw_sample(dtype)
return self._op.draw_sample(not from_inverse, dtype)
return self._op.draw_sample(from_inverse, dtype)
......@@ -48,5 +48,8 @@ class SandwichOperator(EndomorphicOperator):
def apply(self, x, mode):
return self._op.apply(x, mode)
def draw_sample(self, dtype=np.float64):
return self._bun.adjoint_times(self._cheese.draw_sample(dtype))
def draw_sample(self, from_inverse=False, dtype=np.float64):
if from_inverse:
raise ValueError("cannot draw from inverse of this operator")
return self._bun.adjoint_times(
self._cheese.draw_sample(from_inverse, dtype))
......@@ -93,14 +93,10 @@ class ScalingOperator(EndomorphicOperator):
def capability(self):
return self._all_ops
def _sample_helper(self, fct, dtype):
def draw_sample(self, from_inverse=False, dtype=np.float64):
fct = self._factor
if fct.imag != 0. or fct.real <= 0.:
raise ValueError("operator not positive definite")
fct = 1./np.sqrt(fct) if from_inverse else np.sqrt(fct)
return Field.from_random(
random_type="normal", domain=self._domain, std=fct, dtype=dtype)
def draw_sample(self, dtype=np.float64):
return self._sample_helper(np.sqrt(self._factor), dtype)
def inverse_draw_sample(self, dtype=np.float64):
return self._sample_helper(1./np.sqrt(self._factor), dtype)
......@@ -143,8 +143,10 @@ class SumOperator(LinearOperator):
res += op.apply(x, mode)
return res
def draw_sample(self, dtype=np.float64):
res = self._ops[0].draw_sample(dtype)
def draw_sample(self, from_inverse=False, dtype=np.float64):
if from_inverse:
raise ValueError("cannot draw from inverse of this operator")
res = self._ops[0].draw_sample(from_inverse, dtype)
for op in self._ops[1:]:
res += op.draw_sample(dtype)
res += op.draw_sample(from_inverse, dtype)
return res
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment