Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
6cc94a9e
Commit
6cc94a9e
authored
Nov 27, 2019
by
Philipp Arras
Browse files
Export useful assert function
parent
14052dd3
Pipeline
#64451
passed with stages
in 9 minutes and 35 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty5/extra.py
View file @
6cc94a9e
...
...
@@ -26,15 +26,16 @@ from .multi_field import MultiField
from
.operators.linear_operator
import
LinearOperator
from
.sugar
import
from_random
__all__
=
[
"consistency_check"
,
"check_jacobian_consistency"
]
__all__
=
[
"consistency_check"
,
"check_jacobian_consistency"
,
"assert_allclose"
]
def
_
assert_allclose
(
f1
,
f2
,
atol
,
rtol
):
def
assert_allclose
(
f1
,
f2
,
atol
,
rtol
):
if
isinstance
(
f1
,
Field
):
return
np
.
testing
.
assert_allclose
(
f1
.
local_data
,
f2
.
local_data
,
atol
=
atol
,
rtol
=
rtol
)
for
key
,
val
in
f1
.
items
():
_
assert_allclose
(
val
,
f2
[
key
],
atol
=
atol
,
rtol
=
rtol
)
assert_allclose
(
val
,
f2
[
key
],
atol
=
atol
,
rtol
=
rtol
)
def
_adjoint_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
,
...
...
@@ -57,11 +58,11 @@ def _inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
return
foo
=
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
)
res
=
op
(
op
.
inverse_times
(
foo
))
_
assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
foo
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
)
res
=
op
.
inverse_times
(
op
(
foo
))
_
assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
assert_allclose
(
res
,
foo
,
atol
=
atol
,
rtol
=
rtol
)
def
_full_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
,
...
...
@@ -80,7 +81,7 @@ def _check_linearity(op, domain_dtype, atol, rtol):
alpha
=
np
.
random
.
random
()
# FIXME: this can break badly with MPI!
val1
=
op
(
alpha
*
fld1
+
fld2
)
val2
=
alpha
*
op
(
fld1
)
+
op
(
fld2
)
_
assert_allclose
(
val1
,
val2
,
atol
=
atol
,
rtol
=
rtol
)
assert_allclose
(
val1
,
val2
,
atol
=
atol
,
rtol
=
rtol
)
def
_actual_domain_check
(
op
,
domain_dtype
=
None
,
inp
=
None
):
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment