Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
ecd73d29
Commit
ecd73d29
authored
Nov 28, 2020
by
Philipp Arras
Browse files
Add get_sqrt() to the most important operators
parent
f4703ca5
Changes
6
Hide whitespace changes
Inline
Side-by-side
src/extra.py
View file @
ecd73d29
...
...
@@ -85,6 +85,10 @@ def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64,
only_r_linear
)
_full_implementation
(
op
.
adjoint
.
inverse
,
domain_dtype
,
target_dtype
,
atol
,
rtol
,
only_r_linear
)
_check_sqrt
(
op
,
domain_dtype
)
_check_sqrt
(
op
.
adjoint
,
target_dtype
)
_check_sqrt
(
op
.
inverse
,
target_dtype
)
_check_sqrt
(
op
.
adjoint
.
inverse
,
domain_dtype
)
def
check_operator
(
op
,
loc
,
tol
=
1e-12
,
ntries
=
100
,
perf_check
=
True
,
...
...
@@ -197,6 +201,23 @@ def _domain_check_linear(op, domain_dtype=None, inp=None):
myassert
(
op
(
inp
).
domain
is
op
.
target
)
def
_check_sqrt
(
op
,
domain_dtype
):
if
not
is_endo
(
op
):
try
:
op
.
get_sqrt
()
raise
RuntimeError
(
"Operator implements get_sqrt() although it is not an endomorphic operator."
)
except
AttributeError
:
return
try
:
sqop
=
op
.
get_sqrt
()
except
(
NotImplementedError
,
AttributeError
):
return
fld
=
from_random
(
op
.
domain
,
dtype
=
domain_dtype
)
a
=
op
(
fld
)
b
=
(
sqop
.
adjoint
@
sqop
)(
fld
)
return
assert_allclose
(
a
,
b
,
rtol
=
1e-15
)
def
_domain_check_nonlinear
(
op
,
loc
):
_domain_check
(
op
)
myassert
(
isinstance
(
loc
,
(
Field
,
MultiField
)))
...
...
src/operators/block_diagonal_operator.py
View file @
ecd73d29
...
...
@@ -48,7 +48,14 @@ class BlockDiagonalOperator(EndomorphicOperator):
raise
TypeError
(
"LinearOperator expected"
)
def
get_sqrt
(
self
):
ops
=
{
kk
:
vv
.
sqrt
()
for
kk
,
vv
in
self
.
_ops
.
items
()
if
vv
is
not
None
}
ops
=
{}
for
ii
,
kk
in
enumerate
(
self
.
_domain
.
keys
()):
if
self
.
_ops
[
ii
]
is
None
:
continue
try
:
ops
[
kk
]
=
self
.
_ops
[
ii
].
get_sqrt
()
except
AttributeError
:
raise
NotImplementedError
return
BlockDiagonalOperator
(
self
.
_domain
,
ops
)
def
apply
(
self
,
x
,
mode
):
...
...
src/operators/diagonal_operator.py
View file @
ecd73d29
...
...
@@ -166,5 +166,10 @@ class DiagonalOperator(EndomorphicOperator):
res
=
Field
.
from_random
(
domain
=
self
.
_domain
,
random_type
=
"normal"
,
dtype
=
dtype
)
return
self
.
process_sample
(
res
,
from_inverse
)
def
get_sqrt
(
self
):
if
not
np
.
iscomplexobj
(
self
.
_ldiag
)
or
(
self
.
_ldiag
<
0
).
any
():
raise
NotImplementedError
return
self
.
_from_ldiag
(
None
,
np
.
sqrt
(
self
.
_ldiag
))
def
__repr__
(
self
):
return
"DiagonalOperator"
src/operators/endomorphic_operator.py
View file @
ecd73d29
...
...
@@ -75,6 +75,20 @@ class EndomorphicOperator(LinearOperator):
"""
raise
NotImplementedError
def
get_sqrt
(
self
):
"""Return operator op which obeys `self == op.adjoint @ op`.
Note that this function is only implemented for operators with real
spectrum.
Returns
-------
EndomorphicOperator
Operator which is the square root of `self`
"""
raise
NotImplementedError
def
_dom
(
self
,
mode
):
return
self
.
_domain
...
...
src/operators/sandwich_operator.py
View file @
ecd73d29
...
...
@@ -93,6 +93,11 @@ class SandwichOperator(EndomorphicOperator):
return
self
.
_bun
.
adjoint_times
(
self
.
_cheese
.
draw_sample
(
from_inverse
))
def
get_sqrt
(
self
):
if
self
.
_cheese
is
None
:
return
self
.
_bun
return
self
.
_cheese
.
get_sqrt
()
@
self
.
_bun
def
__repr__
(
self
):
from
..utilities
import
indent
return
"
\n
"
.
join
((
...
...
src/operators/scaling_operator.py
View file @
ecd73d29
...
...
@@ -95,6 +95,12 @@ class ScalingOperator(EndomorphicOperator):
from
..sugar
import
from_random
return
from_random
(
domain
=
self
.
_domain
,
random_type
=
"normal"
,
dtype
=
dtype
,
std
=
self
.
_get_fct
(
from_inverse
))
def
get_sqrt
(
self
):
fct
=
self
.
_get_fct
(
False
)
if
np
.
iscomplexobj
(
fct
)
or
fct
<
0
:
raise
NotImplementedError
return
ScalingOperator
(
self
.
_domain
,
fct
)
def
__call__
(
self
,
other
):
res
=
EndomorphicOperator
.
__call__
(
self
,
other
)
if
np
.
isreal
(
self
.
_factor
)
and
self
.
_factor
>=
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment