Commit 67eb41ab authored by clienhar's avatar clienhar
Browse files

implemented inverse_draw_sample()

- moved draw_sample() and inverse_draw_sample() to EndomorphicOperator
- inverse_draw_sample() is always possible if draw_samples() is implemented (see EndomorphicOperator)
- changed WienerCurvature draw_sample() to probably better version
- implemented both functions in DiagonalOperator and ScalingOperator properly (i.e. check for positive definiteness)
parent 37e802de
...@@ -60,12 +60,8 @@ class WienerFilterCurvature(EndomorphicOperator): ...@@ -60,12 +60,8 @@ class WienerFilterCurvature(EndomorphicOperator):
def apply(self, x, mode): def apply(self, x, mode):
return self._op.apply(x, mode) return self._op.apply(x, mode)
def draw_inverse_sample(self, dtype=np.float64): def draw_sample(self, dtype=np.float64):
n = self.N.draw_sample(dtype) n = self.N.inverse_draw_sample(dtype)
s = self.S.draw_sample(dtype) s = self.S.inverse_draw_sample(dtype)
d = self.R(s) + n return s - self.R.adjoint_times(n)
j = self.R.adjoint_times(self.N.inverse_times(d))
m = self.inverse_times(j)
return s - m
...@@ -147,10 +147,18 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -147,10 +147,18 @@ class DiagonalOperator(EndomorphicOperator):
return res return res
def draw_sample(self, dtype=np.float64): def draw_sample(self, dtype=np.float64):
if np.issubdtype(self._ldiag.dtype, np.complexfloating): if np.issubdtype(self._ldiag.dtype, np.complexfloating) or (self._ldiag <= 0.).any():
raise ValueError("cannot draw sample from complex-valued operator") raise ValueError("operator not positive definite")
res = Field.from_random(random_type="normal", domain=self._domain, res = Field.from_random(random_type="normal", domain=self._domain,
dtype=dtype) dtype=dtype)
res.local_data[()] *= np.sqrt(self._ldiag) res.local_data[()] *= np.sqrt(self._ldiag)
return res return res
def inverse_draw_sample(self, dtype=np.float64):
if np.issubdtype(self._ldiag.dtype, np.complexfloating) or (self._ldiag <= 0.).any():
raise ValueError("operator not positive definite")
res = Field.from_random(random_type="normal", domain=self._domain,
dtype=dtype)
res.local_data[()] /= np.sqrt(self._ldiag)
return res
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from .linear_operator import LinearOperator from .linear_operator import LinearOperator
import numpy as np
class EndomorphicOperator(LinearOperator): class EndomorphicOperator(LinearOperator):
...@@ -34,3 +35,32 @@ class EndomorphicOperator(LinearOperator): ...@@ -34,3 +35,32 @@ class EndomorphicOperator(LinearOperator):
Returns `self.domain`, because this is also the target domain Returns `self.domain`, because this is also the target domain
for endomorphic operators.""" for endomorphic operators."""
return self.domain return self.domain
def draw_sample(self, dtype=np.float64):
"""Generate a zero-mean sample
Generates a sample from a Gaussian distribution with zero mean and
covariance given by the operator.
Returns
-------
Field
A sample from the Gaussian of given covariance.
"""
raise NotImplementedError
def inverse_draw_sample(self, dtype=np.float64):
"""Generates a zero-mean sample
Generates a sample from a Gaussian distribution with zero mean and
covariance given by the inverse of the operator.
Returns
-------
A sample from the Gaussian of given covariance
"""
if self.capability & self.INVERSE_TIMES:
x = self.draw_sample(dtype=dtype)
return self.inverse_times(x)
else:
raise NotImplementedError
...@@ -264,16 +264,3 @@ class LinearOperator(NiftyMetaBase()): ...@@ -264,16 +264,3 @@ class LinearOperator(NiftyMetaBase()):
self._check_mode(mode) self._check_mode(mode)
if x.domain != self._dom(mode): if x.domain != self._dom(mode):
raise ValueError("The operator's and field's domains don't match.") raise ValueError("The operator's and field's domains don't match.")
def draw_sample(self):
"""Generate a zero-mean sample
Generates a sample from a Gaussian distribution with zero mean and
covariance given by the operator.
Returns
-------
Field
A sample from the Gaussian of given covariance.
"""
raise NotImplementedError
...@@ -100,8 +100,17 @@ class ScalingOperator(EndomorphicOperator): ...@@ -100,8 +100,17 @@ class ScalingOperator(EndomorphicOperator):
def draw_sample(self, dtype=np.float64): def draw_sample(self, dtype=np.float64):
if self._factor.imag != 0. or self._factor.real <= 0.: if self._factor.imag != 0. or self._factor.real <= 0.:
raise ValueError("Operator not positive definite") raise ValueError("operator not positive definite")
return Field.from_random(random_type="normal", return Field.from_random(random_type="normal",
domain=self._domain, domain=self._domain,
std=np.sqrt(self._factor), std=np.sqrt(self._factor),
dtype=dtype) dtype=dtype)
def inverse_draw_sample(self, dtype=np.float64):
if self._factor.imag != 0. or self._factor.real <= 0.:
raise ValueError("operator not positive definite")
return Field.from_random(random_type="normal",
domain=self._domain,
std=1./np.sqrt(self._factor),
dtype=dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment