Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
3aeba77e
Commit
3aeba77e
authored
Jun 22, 2020
by
Philipp Arras
Browse files
Support operators with very big domains improvements
parent
8da1276e
Pipeline
#77076
passed with stages
in 13 minutes and 19 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/extra.py
View file @
3aeba77e
...
...
@@ -87,8 +87,7 @@ def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64,
def
check_operator
(
op
,
loc
,
tol
=
1e-12
,
ntries
=
100
,
perf_check
=
True
,
only_r_differentiable
=
True
,
metric_sampling
=
True
,
max_combinations
=
np
.
inf
):
only_r_differentiable
=
True
,
metric_sampling
=
True
):
"""
Performs various checks of the implementation of linear and nonlinear
operators.
...
...
@@ -113,9 +112,6 @@ def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
metric_sampling: Boolean
If op is an EnergyOperator, metric_sampling determines whether the
test shall try to sample from the metric or not.
max_combinations : Integer
The maximum number of combinations of keys which shall be used for
checking partially constant operator and its derivative.
"""
if
not
isinstance
(
op
,
Operator
):
raise
TypeError
(
'This test tests only linear operators.'
)
...
...
@@ -125,7 +121,7 @@ def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
_jac_vs_finite_differences
(
op
,
loc
,
np
.
sqrt
(
tol
),
ntries
,
only_r_differentiable
)
_check_nontrivial_constant
(
op
,
loc
,
tol
,
ntries
,
only_r_differentiable
,
metric_sampling
,
max_combinations
)
metric_sampling
)
def
assert_allclose
(
f1
,
f2
,
atol
,
rtol
):
...
...
@@ -318,22 +314,19 @@ def _linearization_value_consistency(op, loc):
def
_check_nontrivial_constant
(
op
,
loc
,
tol
,
ntries
,
only_r_differentiable
,
metric_sampling
,
max_combinations
=
np
.
inf
):
metric_sampling
):
if
isinstance
(
op
.
domain
,
DomainTuple
):
return
keys
=
op
.
domain
.
keys
()
combis
=
[]
if
len
(
keys
)
>
15
:
if
len
(
keys
)
>
4
:
from
.logger
import
logger
logger
.
warning
(
'Operator domain has more than
15
keys.'
)
logger
.
warning
(
'Operator domain has more than
4
keys.'
)
logger
.
warning
(
'Check derivatives only with one constant key at a time.'
)
combis
=
list
(
keys
)
combis
=
[[
kk
]
for
kk
in
keys
]
else
:
for
ll
in
range
(
1
,
len
(
keys
)):
combis
.
extend
(
list
(
combinations
(
keys
,
ll
)))
if
len
(
combis
)
>
max_combinations
:
combis
=
random
.
current_rng
().
choice
(
combis
,
int
(
max_combinations
),
replace
=
False
)
for
cstkeys
in
combis
:
varkeys
=
set
(
keys
)
-
set
(
cstkeys
)
cstloc
=
loc
.
extract_by_keys
(
cstkeys
)
...
...
test/test_operators/test_correlated_fields.py
View file @
3aeba77e
...
...
@@ -108,7 +108,5 @@ def testAmplitudesInvariants(sspace, N):
assert_
(
op
.
target
[
-
1
]
==
fsspace
)
for
ampl
in
fa
.
normalized_amplitudes
:
ift
.
extra
.
check_operator
(
ampl
,
0.1
*
ift
.
from_random
(
ampl
.
domain
),
ntries
=
10
,
max_combinations
=
3
)
ift
.
extra
.
check_operator
(
op
,
0.1
*
ift
.
from_random
(
op
.
domain
),
ntries
=
10
,
max_combinations
=
5
)
ift
.
extra
.
check_operator
(
ampl
,
0.1
*
ift
.
from_random
(
ampl
.
domain
),
ntries
=
10
)
ift
.
extra
.
check_operator
(
op
,
0.1
*
ift
.
from_random
(
op
.
domain
),
ntries
=
10
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment