Commit ecd73d29 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add get_sqrt() to the most important operators

parent f4703ca5
...@@ -85,6 +85,10 @@ def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64, ...@@ -85,6 +85,10 @@ def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64,
only_r_linear) only_r_linear)
_full_implementation(op.adjoint.inverse, domain_dtype, target_dtype, atol, _full_implementation(op.adjoint.inverse, domain_dtype, target_dtype, atol,
rtol, only_r_linear) rtol, only_r_linear)
_check_sqrt(op, domain_dtype)
_check_sqrt(op.adjoint, target_dtype)
_check_sqrt(op.inverse, target_dtype)
_check_sqrt(op.adjoint.inverse, domain_dtype)
def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True, def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
...@@ -197,6 +201,23 @@ def _domain_check_linear(op, domain_dtype=None, inp=None): ...@@ -197,6 +201,23 @@ def _domain_check_linear(op, domain_dtype=None, inp=None):
myassert(op(inp).domain is op.target) myassert(op(inp).domain is op.target)
def _check_sqrt(op, domain_dtype):
if not is_endo(op):
try:
op.get_sqrt()
raise RuntimeError("Operator implements get_sqrt() although it is not an endomorphic operator.")
except AttributeError:
return
try:
sqop = op.get_sqrt()
except (NotImplementedError, AttributeError):
return
fld = from_random(op.domain, dtype=domain_dtype)
a = op(fld)
b = (sqop.adjoint @ sqop)(fld)
return assert_allclose(a, b, rtol=1e-15)
def _domain_check_nonlinear(op, loc): def _domain_check_nonlinear(op, loc):
_domain_check(op) _domain_check(op)
myassert(isinstance(loc, (Field, MultiField))) myassert(isinstance(loc, (Field, MultiField)))
......
...@@ -48,7 +48,14 @@ class BlockDiagonalOperator(EndomorphicOperator): ...@@ -48,7 +48,14 @@ class BlockDiagonalOperator(EndomorphicOperator):
raise TypeError("LinearOperator expected") raise TypeError("LinearOperator expected")
def get_sqrt(self): def get_sqrt(self):
ops = {kk: vv.sqrt() for kk, vv in self._ops.items() if vv is not None} ops = {}
for ii, kk in enumerate(self._domain.keys()):
if self._ops[ii] is None:
continue
try:
ops[kk] = self._ops[ii].get_sqrt()
except AttributeError:
raise NotImplementedError
return BlockDiagonalOperator(self._domain, ops) return BlockDiagonalOperator(self._domain, ops)
def apply(self, x, mode): def apply(self, x, mode):
......
...@@ -166,5 +166,10 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -166,5 +166,10 @@ class DiagonalOperator(EndomorphicOperator):
res = Field.from_random(domain=self._domain, random_type="normal", dtype=dtype) res = Field.from_random(domain=self._domain, random_type="normal", dtype=dtype)
return self.process_sample(res, from_inverse) return self.process_sample(res, from_inverse)
def get_sqrt(self):
if not np.iscomplexobj(self._ldiag) or (self._ldiag < 0).any():
raise NotImplementedError
return self._from_ldiag(None, np.sqrt(self._ldiag))
def __repr__(self): def __repr__(self):
return "DiagonalOperator" return "DiagonalOperator"
...@@ -75,6 +75,20 @@ class EndomorphicOperator(LinearOperator): ...@@ -75,6 +75,20 @@ class EndomorphicOperator(LinearOperator):
""" """
raise NotImplementedError raise NotImplementedError
def get_sqrt(self):
"""Return operator op which obeys `self == op.adjoint @ op`.
Note that this function is only implemented for operators with real
spectrum.
Returns
-------
EndomorphicOperator
Operator which is the square root of `self`
"""
raise NotImplementedError
def _dom(self, mode): def _dom(self, mode):
return self._domain return self._domain
......
...@@ -93,6 +93,11 @@ class SandwichOperator(EndomorphicOperator): ...@@ -93,6 +93,11 @@ class SandwichOperator(EndomorphicOperator):
return self._bun.adjoint_times( return self._bun.adjoint_times(
self._cheese.draw_sample(from_inverse)) self._cheese.draw_sample(from_inverse))
def get_sqrt(self):
if self._cheese is None:
return self._bun
return self._cheese.get_sqrt() @ self._bun
def __repr__(self): def __repr__(self):
from ..utilities import indent from ..utilities import indent
return "\n".join(( return "\n".join((
......
...@@ -95,6 +95,12 @@ class ScalingOperator(EndomorphicOperator): ...@@ -95,6 +95,12 @@ class ScalingOperator(EndomorphicOperator):
from ..sugar import from_random from ..sugar import from_random
return from_random(domain=self._domain, random_type="normal", dtype=dtype, std=self._get_fct(from_inverse)) return from_random(domain=self._domain, random_type="normal", dtype=dtype, std=self._get_fct(from_inverse))
def get_sqrt(self):
fct = self._get_fct(False)
if np.iscomplexobj(fct) or fct < 0:
raise NotImplementedError
return ScalingOperator(self._domain, fct)
def __call__(self, other): def __call__(self, other):
res = EndomorphicOperator.__call__(self, other) res = EndomorphicOperator.__call__(self, other)
if np.isreal(self._factor) and self._factor >= 0: if np.isreal(self._factor) and self._factor >= 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment