Commit cf4f4e6f authored by Martin Reinecke's avatar Martin Reinecke

allow sample drawing with user-defined dtypes

parent 9021e88f
......@@ -20,6 +20,7 @@ from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.inversion_enabler import InversionEnabler
from ..field import Field, sqrt
from ..sugar import power_analyze, power_synthesize
import numpy as np
class WienerFilterCurvature(EndomorphicOperator):
......@@ -61,9 +62,9 @@ class WienerFilterCurvature(EndomorphicOperator):
def apply(self, x, mode):
return self._op.apply(x, mode)
def draw_sample(self):
n = self.N.draw_sample()
s = self.S.draw_sample()
def draw_sample(self, dtype=np.float64):
n = self.N.draw_sample(dtype)
s = self.S.draw_sample(dtype)
d = self.R(s) + n
......
......@@ -133,12 +133,11 @@ class DiagonalOperator(EndomorphicOperator):
return DiagonalOperator(self._diagonal.conjugate(), self._domain,
self._spaces)
def draw_sample(self):
def draw_sample(self, dtype=np.float64):
if np.issubdtype(self._ldiag.dtype, np.complexfloating):
raise ValueError("cannot draw sample from complex-valued operator")
res = Field.from_random(random_type="normal",
domain=self._domain,
dtype=self._diagonal.dtype)
res = Field.from_random(random_type="normal", domain=self._domain,
dtype=dtype)
res.local_data[()] *= np.sqrt(self._ldiag)
return res
......@@ -83,8 +83,10 @@ class ScalingOperator(EndomorphicOperator):
return self.TIMES | self.ADJOINT_TIMES
return self._all_ops
def draw_sample(self):
def draw_sample(self, dtype=np.float64):
if self._factor.imag != 0. or self._factor.real <= 0.:
raise ValueError("Operator not positive definite")
return Field.from_random(random_type="normal",
domain=self._domain,
std=np.sqrt(self._factor),
dtype=np.result_type(self._factor))
dtype=dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment