Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
db698ac1
Commit
db698ac1
authored
Jun 08, 2021
by
Philipp Arras
Browse files
Check operators for purity
parent
11bc2c37
Pipeline
#103100
passed with stages
in 14 minutes and 7 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/extra.py
View file @
db698ac1
...
@@ -73,6 +73,10 @@ def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64,
...
@@ -73,6 +73,10 @@ def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64,
_domain_check_linear
(
op
.
adjoint
,
target_dtype
)
_domain_check_linear
(
op
.
adjoint
,
target_dtype
)
_domain_check_linear
(
op
.
inverse
,
target_dtype
)
_domain_check_linear
(
op
.
inverse
,
target_dtype
)
_domain_check_linear
(
op
.
adjoint
.
inverse
,
domain_dtype
)
_domain_check_linear
(
op
.
adjoint
.
inverse
,
domain_dtype
)
_purity_check
(
op
,
from_random
(
op
.
domain
,
dtype
=
domain_dtype
))
_purity_check
(
op
.
adjoint
.
inverse
,
from_random
(
op
.
domain
,
dtype
=
domain_dtype
))
_purity_check
(
op
.
adjoint
,
from_random
(
op
.
target
,
dtype
=
target_dtype
))
_purity_check
(
op
.
inverse
,
from_random
(
op
.
target
,
dtype
=
target_dtype
))
_check_linearity
(
op
,
domain_dtype
,
atol
,
rtol
)
_check_linearity
(
op
,
domain_dtype
,
atol
,
rtol
)
_check_linearity
(
op
.
adjoint
,
target_dtype
,
atol
,
rtol
)
_check_linearity
(
op
.
adjoint
,
target_dtype
,
atol
,
rtol
)
_check_linearity
(
op
.
inverse
,
target_dtype
,
atol
,
rtol
)
_check_linearity
(
op
.
inverse
,
target_dtype
,
atol
,
rtol
)
...
@@ -120,6 +124,7 @@ def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
...
@@ -120,6 +124,7 @@ def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
if
not
isinstance
(
op
,
Operator
):
if
not
isinstance
(
op
,
Operator
):
raise
TypeError
(
'This test tests only (nonlinear) operators.'
)
raise
TypeError
(
'This test tests only (nonlinear) operators.'
)
_domain_check_nonlinear
(
op
,
loc
)
_domain_check_nonlinear
(
op
,
loc
)
_purity_check
(
op
,
loc
)
_performance_check
(
op
,
loc
,
bool
(
perf_check
))
_performance_check
(
op
,
loc
,
bool
(
perf_check
))
_linearization_value_consistency
(
op
,
loc
)
_linearization_value_consistency
(
op
,
loc
)
_jac_vs_finite_differences
(
op
,
loc
,
np
.
sqrt
(
tol
),
ntries
,
_jac_vs_finite_differences
(
op
,
loc
,
np
.
sqrt
(
tol
),
ntries
,
...
@@ -288,6 +293,14 @@ def _performance_check(op, pos, raise_on_fail):
...
@@ -288,6 +293,14 @@ def _performance_check(op, pos, raise_on_fail):
raise
RuntimeError
(
s
)
raise
RuntimeError
(
s
)
def
_purity_check
(
op
,
pos
):
if
isinstance
(
op
,
LinearOperator
)
and
(
op
.
capability
&
op
.
TIMES
)
!=
op
.
TIMES
:
return
res0
=
op
(
pos
)
res1
=
op
(
pos
)
assert_equal
(
res0
,
res1
)
def
_get_acceptable_location
(
op
,
loc
,
lin
):
def
_get_acceptable_location
(
op
,
loc
,
lin
):
if
not
np
.
isfinite
(
lin
.
val
.
s_sum
()):
if
not
np
.
isfinite
(
lin
.
val
.
s_sum
()):
raise
ValueError
(
'Initial value must be finite'
)
raise
ValueError
(
'Initial value must be finite'
)
...
...
test/test_extra.py
0 → 100644
View file @
db698ac1
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
# Author: Philipp Arras
import
numpy
as
np
import
pytest
import
nifty7
as
ift
from
time
import
time
from
.common
import
list2fixture
,
setup_function
,
teardown_function
pmp
=
pytest
.
mark
.
parametrize
class
NonPureOperator
(
ift
.
Operator
):
def
__init__
(
self
,
domain
):
self
.
_domain
=
self
.
_target
=
ift
.
makeDomain
(
domain
)
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
return
x
*
time
()
class
NonPureLinearOperator
(
ift
.
LinearOperator
):
def
__init__
(
self
,
domain
,
cap
):
self
.
_domain
=
self
.
_target
=
ift
.
makeDomain
(
domain
)
self
.
_capability
=
cap
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
return
x
*
time
()
@
pmp
(
"cap"
,
[
ift
.
LinearOperator
.
ADJOINT_TIMES
,
ift
.
LinearOperator
.
INVERSE_TIMES
|
ift
.
LinearOperator
.
TIMES
])
@
pmp
(
"ddtype"
,
[
np
.
float64
,
np
.
complex128
])
@
pmp
(
"tdtype"
,
[
np
.
float64
,
np
.
complex128
])
def
test_purity_check_linear
(
cap
,
ddtype
,
tdtype
):
dom
=
ift
.
RGSpace
(
2
)
op
=
NonPureLinearOperator
(
dom
,
cap
)
with
pytest
.
raises
(
AssertionError
):
ift
.
extra
.
check_linear_operator
(
op
,
ddtype
,
tdtype
)
@
pmp
(
"dtype"
,
[
np
.
float64
,
np
.
complex128
])
def
test_purity_check
(
dtype
):
dom
=
ift
.
RGSpace
(
2
)
op
=
NonPureOperator
(
dom
)
with
pytest
.
raises
(
AssertionError
):
ift
.
extra
.
check_operator
(
op
,
dtype
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment