From e9a5b0f18b33c1e16238721eebd815b54f98efec Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Sun, 1 Apr 2018 13:55:42 +0200 Subject: [PATCH] make inverse_draw_sample() largely obsolete --- nifty4/operators/diagonal_operator.py | 17 +++++------------ nifty4/operators/endomorphic_operator.py | 15 +++++++++------ nifty4/operators/inversion_enabler.py | 17 ++++++----------- nifty4/operators/operator_adapter.py | 11 +++-------- nifty4/operators/sandwich_operator.py | 7 +++++-- nifty4/operators/scaling_operator.py | 10 +++------- nifty4/operators/sum_operator.py | 8 +++++--- 7 files changed, 36 insertions(+), 49 deletions(-) diff --git a/nifty4/operators/diagonal_operator.py b/nifty4/operators/diagonal_operator.py index c2d0b5dd7..9c3652226 100644 --- a/nifty4/operators/diagonal_operator.py +++ b/nifty4/operators/diagonal_operator.py @@ -174,21 +174,14 @@ class DiagonalOperator(EndomorphicOperator): raise ValueError("bad operator flipping mode") return res - def draw_sample(self, dtype=np.float64): + def draw_sample(self, from_inverse=False, dtype=np.float64): if (np.issubdtype(self._ldiag.dtype, np.complexfloating) or (self._ldiag <= 0.).any()): raise ValueError("operator not positive definite") res = Field.from_random(random_type="normal", domain=self._domain, dtype=dtype) - res.local_data[()] *= np.sqrt(self._ldiag) - return res - - def inverse_draw_sample(self, dtype=np.float64): - if (np.issubdtype(self._ldiag.dtype, np.complexfloating) or - (self._ldiag <= 0.).any()): - raise ValueError("operator not positive definite") - - res = Field.from_random(random_type="normal", domain=self._domain, - dtype=dtype) - res.local_data[()] /= np.sqrt(self._ldiag) + if from_inverse: + res.local_data[()] /= np.sqrt(self._ldiag) + else: + res.local_data[()] *= np.sqrt(self._ldiag) return res diff --git a/nifty4/operators/endomorphic_operator.py b/nifty4/operators/endomorphic_operator.py index e456923c0..0046d8986 100644 --- a/nifty4/operators/endomorphic_operator.py +++ b/nifty4/operators/endomorphic_operator.py @@ -36,12 +36,19 @@ class EndomorphicOperator(LinearOperator): for endomorphic operators.""" return self.domain - def draw_sample(self, dtype=np.float64): + def draw_sample(self, from_inverse=False, dtype=np.float64): """Generate a zero-mean sample Generates a sample from a Gaussian distribution with zero mean and covariance given by the operator. + Parameters + ---------- + from_inverse : bool (default : False) + if True, the sample is drawn from the inverse of the operator + dtype : numpy datatype (default : numpy.float64) + the data type to be used for the sample + Returns ------- Field @@ -59,8 +66,4 @@ class EndomorphicOperator(LinearOperator): ------- A sample from the Gaussian of given covariance """ - if self.capability & self.INVERSE_TIMES: - x = self.draw_sample(dtype) - return self.inverse_times(x) - else: - raise NotImplementedError + return self.draw_sample(True, dtype) diff --git a/nifty4/operators/inversion_enabler.py b/nifty4/operators/inversion_enabler.py index 774318cb8..dbdd59314 100644 --- a/nifty4/operators/inversion_enabler.py +++ b/nifty4/operators/inversion_enabler.py @@ -20,11 +20,11 @@ from ..minimization.quadratic_energy import QuadraticEnergy from ..minimization.iteration_controller import IterationController from ..field import Field from ..logger import logger -from .linear_operator import LinearOperator +from .endomorphic_operator import EndomorphicOperator import numpy as np -class InversionEnabler(LinearOperator): +class InversionEnabler(EndomorphicOperator): """Class which augments the capability of another operator object via numerical inversion. @@ -80,14 +80,9 @@ class InversionEnabler(LinearOperator): logger.warning("Error detected during operator inversion") return r.position - def draw_sample(self, dtype=np.float64): + def draw_sample(self, from_inverse=False, dtype=np.float64): try: - return self._op.draw_sample(dtype) + return self._op.draw_sample(from_inverse, dtype) except: - return self(self._op.inverse_draw_sample(dtype)) - - def inverse_draw_sample(self, dtype=np.float64): - try: - return self._op.inverse_draw_sample(dtype) - except: - return self.inverse_times(self._op.draw_sample(dtype)) + samp = self._op.draw_sample(not from_inverse, dtype) + return self.inverse_times(samp) if from_inverse else self(samp) diff --git a/nifty4/operators/operator_adapter.py b/nifty4/operators/operator_adapter.py index 61ce31865..036045e79 100644 --- a/nifty4/operators/operator_adapter.py +++ b/nifty4/operators/operator_adapter.py @@ -49,12 +49,7 @@ class OperatorAdapter(LinearOperator): def apply(self, x, mode): return self._op.apply(x, self._modeTable[self._mode][self._ilog[mode]]) - def draw_sample(self, dtype=np.float64): + def draw_sample(self, from_inverse=False, dtype=np.float64): if self._mode & self.INVERSE_BIT: - return self._op.inverse_draw_sample(dtype) - return self._op.draw_sample(dtype) - - def inverse_draw_sample(self, dtype=np.float64): - if self._mode & self.INVERSE_BIT: - return self._op.draw_sample(dtype) - return self._op.inverse_draw_sample(dtype) + return self._op.draw_sample(not from_inverse, dtype) + return self._op.draw_sample(from_inverse, dtype) diff --git a/nifty4/operators/sandwich_operator.py b/nifty4/operators/sandwich_operator.py index e324b135d..464072adb 100644 --- a/nifty4/operators/sandwich_operator.py +++ b/nifty4/operators/sandwich_operator.py @@ -48,5 +48,8 @@ class SandwichOperator(EndomorphicOperator): def apply(self, x, mode): return self._op.apply(x, mode) - def draw_sample(self, dtype=np.float64): - return self._bun.adjoint_times(self._cheese.draw_sample(dtype)) + def draw_sample(self, from_inverse=False, dtype=np.float64): + if from_inverse: + raise ValueError("cannot draw from inverse of this operator") + return self._bun.adjoint_times( + self._cheese.draw_sample(from_inverse, dtype)) diff --git a/nifty4/operators/scaling_operator.py b/nifty4/operators/scaling_operator.py index ddb8eaaae..cce8fee4b 100644 --- a/nifty4/operators/scaling_operator.py +++ b/nifty4/operators/scaling_operator.py @@ -93,14 +93,10 @@ class ScalingOperator(EndomorphicOperator): def capability(self): return self._all_ops - def _sample_helper(self, fct, dtype): + def draw_sample(self, from_inverse=False, dtype=np.float64): + fct = self._factor if fct.imag != 0. or fct.real <= 0.: raise ValueError("operator not positive definite") + fct = 1./np.sqrt(fct) if from_inverse else np.sqrt(fct) return Field.from_random( random_type="normal", domain=self._domain, std=fct, dtype=dtype) - - def draw_sample(self, dtype=np.float64): - return self._sample_helper(np.sqrt(self._factor), dtype) - - def inverse_draw_sample(self, dtype=np.float64): - return self._sample_helper(1./np.sqrt(self._factor), dtype) diff --git a/nifty4/operators/sum_operator.py b/nifty4/operators/sum_operator.py index be36836af..e2d5d70f3 100644 --- a/nifty4/operators/sum_operator.py +++ b/nifty4/operators/sum_operator.py @@ -143,8 +143,10 @@ class SumOperator(LinearOperator): res += op.apply(x, mode) return res - def draw_sample(self, dtype=np.float64): - res = self._ops[0].draw_sample(dtype) + def draw_sample(self, from_inverse=False, dtype=np.float64): + if from_inverse: + raise ValueError("cannot draw from inverse of this operator") + res = self._ops[0].draw_sample(from_inverse, dtype) for op in self._ops[1:]: - res += op.draw_sample(dtype) + res += op.draw_sample(from_inverse, dtype) return res -- GitLab