Commit e9a5b0f1 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

make inverse_draw_sample() largely obsolete

parent ad129166
Pipeline #26707 passed with stage
in 8 minutes and 27 seconds
...@@ -174,21 +174,14 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -174,21 +174,14 @@ class DiagonalOperator(EndomorphicOperator):
raise ValueError("bad operator flipping mode") raise ValueError("bad operator flipping mode")
return res return res
def draw_sample(self, dtype=np.float64): def draw_sample(self, from_inverse=False, dtype=np.float64):
if (np.issubdtype(self._ldiag.dtype, np.complexfloating) or if (np.issubdtype(self._ldiag.dtype, np.complexfloating) or
(self._ldiag <= 0.).any()): (self._ldiag <= 0.).any()):
raise ValueError("operator not positive definite") raise ValueError("operator not positive definite")
res = Field.from_random(random_type="normal", domain=self._domain, res = Field.from_random(random_type="normal", domain=self._domain,
dtype=dtype) dtype=dtype)
res.local_data[()] *= np.sqrt(self._ldiag) if from_inverse:
return res res.local_data[()] /= np.sqrt(self._ldiag)
else:
def inverse_draw_sample(self, dtype=np.float64): res.local_data[()] *= np.sqrt(self._ldiag)
if (np.issubdtype(self._ldiag.dtype, np.complexfloating) or
(self._ldiag <= 0.).any()):
raise ValueError("operator not positive definite")
res = Field.from_random(random_type="normal", domain=self._domain,
dtype=dtype)
res.local_data[()] /= np.sqrt(self._ldiag)
return res return res
...@@ -36,12 +36,19 @@ class EndomorphicOperator(LinearOperator): ...@@ -36,12 +36,19 @@ class EndomorphicOperator(LinearOperator):
for endomorphic operators.""" for endomorphic operators."""
return self.domain return self.domain
def draw_sample(self, dtype=np.float64): def draw_sample(self, from_inverse=False, dtype=np.float64):
"""Generate a zero-mean sample """Generate a zero-mean sample
Generates a sample from a Gaussian distribution with zero mean and Generates a sample from a Gaussian distribution with zero mean and
covariance given by the operator. covariance given by the operator.
Parameters
----------
from_inverse : bool (default : False)
if True, the sample is drawn from the inverse of the operator
dtype : numpy datatype (default : numpy.float64)
the data type to be used for the sample
Returns Returns
------- -------
Field Field
...@@ -59,8 +66,4 @@ class EndomorphicOperator(LinearOperator): ...@@ -59,8 +66,4 @@ class EndomorphicOperator(LinearOperator):
------- -------
A sample from the Gaussian of given covariance A sample from the Gaussian of given covariance
""" """
if self.capability & self.INVERSE_TIMES: return self.draw_sample(True, dtype)
x = self.draw_sample(dtype)
return self.inverse_times(x)
else:
raise NotImplementedError
...@@ -20,11 +20,11 @@ from ..minimization.quadratic_energy import QuadraticEnergy ...@@ -20,11 +20,11 @@ from ..minimization.quadratic_energy import QuadraticEnergy
from ..minimization.iteration_controller import IterationController from ..minimization.iteration_controller import IterationController
from ..field import Field from ..field import Field
from ..logger import logger from ..logger import logger
from .linear_operator import LinearOperator from .endomorphic_operator import EndomorphicOperator
import numpy as np import numpy as np
class InversionEnabler(LinearOperator): class InversionEnabler(EndomorphicOperator):
"""Class which augments the capability of another operator object via """Class which augments the capability of another operator object via
numerical inversion. numerical inversion.
...@@ -80,14 +80,9 @@ class InversionEnabler(LinearOperator): ...@@ -80,14 +80,9 @@ class InversionEnabler(LinearOperator):
logger.warning("Error detected during operator inversion") logger.warning("Error detected during operator inversion")
return r.position return r.position
def draw_sample(self, dtype=np.float64): def draw_sample(self, from_inverse=False, dtype=np.float64):
try: try:
return self._op.draw_sample(dtype) return self._op.draw_sample(from_inverse, dtype)
except: except:
return self(self._op.inverse_draw_sample(dtype)) samp = self._op.draw_sample(not from_inverse, dtype)
return self.inverse_times(samp) if from_inverse else self(samp)
def inverse_draw_sample(self, dtype=np.float64):
try:
return self._op.inverse_draw_sample(dtype)
except:
return self.inverse_times(self._op.draw_sample(dtype))
...@@ -49,12 +49,7 @@ class OperatorAdapter(LinearOperator): ...@@ -49,12 +49,7 @@ class OperatorAdapter(LinearOperator):
def apply(self, x, mode): def apply(self, x, mode):
return self._op.apply(x, self._modeTable[self._mode][self._ilog[mode]]) return self._op.apply(x, self._modeTable[self._mode][self._ilog[mode]])
def draw_sample(self, dtype=np.float64): def draw_sample(self, from_inverse=False, dtype=np.float64):
if self._mode & self.INVERSE_BIT: if self._mode & self.INVERSE_BIT:
return self._op.inverse_draw_sample(dtype) return self._op.draw_sample(not from_inverse, dtype)
return self._op.draw_sample(dtype) return self._op.draw_sample(from_inverse, dtype)
def inverse_draw_sample(self, dtype=np.float64):
if self._mode & self.INVERSE_BIT:
return self._op.draw_sample(dtype)
return self._op.inverse_draw_sample(dtype)
...@@ -48,5 +48,8 @@ class SandwichOperator(EndomorphicOperator): ...@@ -48,5 +48,8 @@ class SandwichOperator(EndomorphicOperator):
def apply(self, x, mode): def apply(self, x, mode):
return self._op.apply(x, mode) return self._op.apply(x, mode)
def draw_sample(self, dtype=np.float64): def draw_sample(self, from_inverse=False, dtype=np.float64):
return self._bun.adjoint_times(self._cheese.draw_sample(dtype)) if from_inverse:
raise ValueError("cannot draw from inverse of this operator")
return self._bun.adjoint_times(
self._cheese.draw_sample(from_inverse, dtype))
...@@ -93,14 +93,10 @@ class ScalingOperator(EndomorphicOperator): ...@@ -93,14 +93,10 @@ class ScalingOperator(EndomorphicOperator):
def capability(self): def capability(self):
return self._all_ops return self._all_ops
def _sample_helper(self, fct, dtype): def draw_sample(self, from_inverse=False, dtype=np.float64):
fct = self._factor
if fct.imag != 0. or fct.real <= 0.: if fct.imag != 0. or fct.real <= 0.:
raise ValueError("operator not positive definite") raise ValueError("operator not positive definite")
fct = 1./np.sqrt(fct) if from_inverse else np.sqrt(fct)
return Field.from_random( return Field.from_random(
random_type="normal", domain=self._domain, std=fct, dtype=dtype) random_type="normal", domain=self._domain, std=fct, dtype=dtype)
def draw_sample(self, dtype=np.float64):
return self._sample_helper(np.sqrt(self._factor), dtype)
def inverse_draw_sample(self, dtype=np.float64):
return self._sample_helper(1./np.sqrt(self._factor), dtype)
...@@ -143,8 +143,10 @@ class SumOperator(LinearOperator): ...@@ -143,8 +143,10 @@ class SumOperator(LinearOperator):
res += op.apply(x, mode) res += op.apply(x, mode)
return res return res
def draw_sample(self, dtype=np.float64): def draw_sample(self, from_inverse=False, dtype=np.float64):
res = self._ops[0].draw_sample(dtype) if from_inverse:
raise ValueError("cannot draw from inverse of this operator")
res = self._ops[0].draw_sample(from_inverse, dtype)
for op in self._ops[1:]: for op in self._ops[1:]:
res += op.draw_sample(dtype) res += op.draw_sample(from_inverse, dtype)
return res return res
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment