Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
3490e9ce
Commit
3490e9ce
authored
Nov 06, 2019
by
Martin Reinecke
Browse files
Merge branch 'more_operator_checks' into 'NIFTy_5'
Add more automatic checks for operators See merge request
!368
parents
046d074c
5defc35b
Changes
4
Hide whitespace changes
Inline
Side-by-side
nifty5/extra.py
View file @
3490e9ce
...
...
@@ -16,11 +16,13 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import
numpy
as
np
from
numpy.testing
import
assert_
from
.domain_tuple
import
DomainTuple
from
.field
import
Field
from
.linearization
import
Linearization
from
.multi_domain
import
MultiDomain
from
.multi_field
import
MultiField
from
.operators.linear_operator
import
LinearOperator
from
.sugar
import
from_random
...
...
@@ -81,6 +83,38 @@ def _check_linearity(op, domain_dtype, atol, rtol):
_assert_allclose
(
val1
,
val2
,
atol
=
atol
,
rtol
=
rtol
)
def
_actual_domain_check
(
op
,
domain_dtype
=
None
,
inp
=
None
):
needed_cap
=
op
.
TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
if
domain_dtype
is
not
None
:
inp
=
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
)
elif
inp
is
None
:
raise
ValueError
(
'Need to specify either dtype or inp'
)
assert_
(
inp
.
domain
is
op
.
domain
)
assert_
(
op
(
inp
).
domain
is
op
.
target
)
def
_actual_domain_check_nonlinear
(
op
,
loc
,
target_dtype
=
np
.
float64
):
assert
isinstance
(
loc
,
(
Field
,
MultiField
))
assert_
(
loc
.
domain
is
op
.
domain
)
lin
=
Linearization
.
make_var
(
loc
,
False
)
reslin
=
op
(
lin
)
assert_
(
lin
.
domain
is
op
.
domain
)
assert_
(
lin
.
target
is
op
.
domain
)
assert_
(
lin
.
val
.
domain
is
lin
.
domain
)
assert_
(
reslin
.
domain
is
op
.
domain
)
assert_
(
reslin
.
target
is
op
.
target
)
assert_
(
reslin
.
val
.
domain
is
reslin
.
target
)
assert_
(
reslin
.
target
is
op
.
target
)
assert_
(
reslin
.
jac
.
domain
is
reslin
.
domain
)
assert_
(
reslin
.
jac
.
target
is
reslin
.
target
)
_actual_domain_check
(
reslin
.
jac
,
inp
=
loc
)
_actual_domain_check
(
reslin
.
jac
.
adjoint
,
domain_dtype
=
target_dtype
)
def
_domain_check
(
op
):
for
dd
in
[
op
.
domain
,
op
.
target
]:
if
not
isinstance
(
dd
,
(
DomainTuple
,
MultiDomain
)):
...
...
@@ -123,6 +157,10 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
if
not
isinstance
(
op
,
LinearOperator
):
raise
TypeError
(
'This test tests only linear operators.'
)
_domain_check
(
op
)
_actual_domain_check
(
op
,
domain_dtype
)
_actual_domain_check
(
op
.
adjoint
,
target_dtype
)
_actual_domain_check
(
op
.
inverse
,
target_dtype
)
_actual_domain_check
(
op
.
adjoint
.
inverse
,
domain_dtype
)
_check_linearity
(
op
,
domain_dtype
,
atol
,
rtol
)
_check_linearity
(
op
.
adjoint
,
target_dtype
,
atol
,
rtol
)
_check_linearity
(
op
.
inverse
,
target_dtype
,
atol
,
rtol
)
...
...
@@ -180,6 +218,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100):
Tolerance for the check.
"""
_domain_check
(
op
)
_actual_domain_check_nonlinear
(
op
,
loc
)
for
_
in
range
(
ntries
):
lin
=
op
(
Linearization
.
make_var
(
loc
))
loc2
,
lin2
=
_get_acceptable_location
(
op
,
loc
,
lin
)
...
...
nifty5/operators/simple_linear_operators.py
View file @
3490e9ce
...
...
@@ -187,9 +187,7 @@ class _SlowFieldAdapter(LinearOperator):
self
.
_check_input
(
x
,
mode
)
if
isinstance
(
x
,
MultiField
):
return
x
[
self
.
_name
]
else
:
return
MultiField
.
from_dict
({
self
.
_name
:
x
},
domain
=
self
.
_tgt
(
mode
))
return
MultiField
.
from_dict
({
self
.
_name
:
x
},
domain
=
self
.
_tgt
(
mode
))
def
__repr__
(
self
):
return
'_SlowFieldAdapter'
...
...
@@ -338,12 +336,17 @@ class PartialExtractor(LinearOperator):
if
self
.
_domain
[
key
]
is
not
self
.
_target
[
key
]:
raise
ValueError
(
"domain mismatch"
)
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
self
.
_compldomain
=
MultiDomain
.
make
({
kk
:
self
.
_domain
[
kk
]
for
kk
in
self
.
_domain
.
keys
()
if
kk
not
in
self
.
_target
.
keys
()})
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
return
x
.
extract
(
self
.
_target
)
return
MultiField
.
from_dict
({
key
:
x
[
key
]
for
key
in
x
.
domain
.
keys
()})
res0
=
MultiField
.
from_dict
({
key
:
x
[
key
]
for
key
in
x
.
domain
.
keys
()})
res1
=
MultiField
.
full
(
self
.
_compldomain
,
0.
)
return
res0
.
unite
(
res1
)
class
MatrixProductOperator
(
EndomorphicOperator
):
...
...
@@ -359,20 +362,19 @@ class MatrixProductOperator(EndomorphicOperator):
`dot()` and `transpose()` in the style of numpy arrays.
"""
def
__init__
(
self
,
domain
,
matrix
):
self
.
_domain
=
domain
self
.
_domain
=
DomainTuple
.
make
(
domain
)
shp
=
self
.
_domain
.
shape
if
len
(
shp
)
>
1
:
raise
TypeError
(
'Only 1D-domain supported yet.'
)
if
matrix
.
shape
!=
(
*
shp
,
*
shp
):
raise
ValueError
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
self
.
_mat
=
matrix
self
.
_mat_tr
=
matrix
.
transpose
()
self
.
_mat_tr
=
matrix
.
transpose
()
.
conjugate
()
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
res
=
x
.
to_global_data
()
if
mode
==
self
.
TIMES
:
res
=
self
.
_mat
.
dot
(
res
)
if
mode
==
self
.
ADJOINT_TIMES
:
res
=
self
.
_mat_tr
.
dot
(
res
)
f
=
self
.
_mat
.
dot
if
mode
==
self
.
TIMES
else
self
.
_mat_tr
.
dot
res
=
f
(
res
)
return
Field
.
from_global_data
(
self
.
_domain
,
res
)
def
__repr__
(
self
):
return
"MatrixProductOperator"
test/test_energy_gradients.py
View file @
3490e9ce
...
...
@@ -53,7 +53,7 @@ def test_inverse_gamma(field):
d
=
np
.
random
.
normal
(
10
,
size
=
space
.
shape
)
**
2
d
=
ift
.
Field
.
from_global_data
(
space
,
d
)
energy
=
ift
.
InverseGammaLikelihood
(
d
)
ift
.
extra
.
check_jacobian_consistency
(
energy
,
field
,
tol
=
1e-
7
)
ift
.
extra
.
check_jacobian_consistency
(
energy
,
field
,
tol
=
1e-
5
)
def
testPoissonian
(
field
):
...
...
@@ -83,4 +83,4 @@ def test_bernoulli(field):
d
=
np
.
random
.
binomial
(
1
,
0.1
,
size
=
space
.
shape
)
d
=
ift
.
Field
.
from_global_data
(
space
,
d
)
energy
=
ift
.
BernoulliEnergy
(
d
)
ift
.
extra
.
check_jacobian_consistency
(
energy
,
field
,
tol
=
1e-
6
)
ift
.
extra
.
check_jacobian_consistency
(
energy
,
field
,
tol
=
1e-
5
)
test/test_operators/test_adjoint.py
View file @
3490e9ce
...
...
@@ -295,3 +295,34 @@ def testValueInserter(sp, seed):
ind
.
append
(
np
.
random
.
randint
(
0
,
ss
-
1
))
op
=
ift
.
ValueInserter
(
sp
,
ind
)
ift
.
extra
.
consistency_check
(
op
)
@
pmp
(
'sp'
,
[
ift
.
RGSpace
(
10
)])
@
pmp
(
'seed'
,
[
12
,
3
])
def
testMatrixProductOperator
(
sp
,
seed
):
np
.
random
.
seed
(
seed
)
mat
=
np
.
random
.
randn
(
*
sp
.
shape
,
*
sp
.
shape
)
op
=
ift
.
MatrixProductOperator
(
sp
,
mat
)
ift
.
extra
.
consistency_check
(
op
)
mat
=
mat
+
1j
*
np
.
random
.
randn
(
*
sp
.
shape
,
*
sp
.
shape
)
op
=
ift
.
MatrixProductOperator
(
sp
,
mat
)
ift
.
extra
.
consistency_check
(
op
)
@
pmp
(
'seed'
,
[
12
,
3
])
def
testPartialExtractor
(
seed
):
np
.
random
.
seed
(
seed
)
tgt
=
{
'a'
:
ift
.
RGSpace
(
1
),
'b'
:
ift
.
RGSpace
(
2
)}
dom
=
tgt
.
copy
()
dom
[
'c'
]
=
ift
.
RGSpace
(
3
)
dom
=
ift
.
MultiDomain
.
make
(
dom
)
tgt
=
ift
.
MultiDomain
.
make
(
tgt
)
op
=
ift
.
PartialExtractor
(
dom
,
tgt
)
ift
.
extra
.
consistency_check
(
op
)
@
pmp
(
'seed'
,
[
12
,
3
])
def
testSlowFieldAdapter
(
seed
):
dom
=
{
'a'
:
ift
.
RGSpace
(
1
),
'b'
:
ift
.
RGSpace
(
2
)}
op
=
ift
.
operators
.
simple_linear_operators
.
_SlowFieldAdapter
(
dom
,
'a'
)
ift
.
extra
.
consistency_check
(
op
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment