Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
9a4d6183
Commit
9a4d6183
authored
Jul 20, 2018
by
Silvan Streit
Browse files
Test models: WIP consistency check
* fails for constant models
parent
bfd57f63
Changes
1
Hide whitespace changes
Inline
Side-by-side
test/test_models/test_model_gradients.py
View file @
9a4d6183
...
...
@@ -25,15 +25,70 @@ import numpy as np
class
Model_Tests
(
unittest
.
TestCase
):
@
expand
(
product
([
ift
.
GLSpace
(
15
),
ift
.
RGSpace
(
64
,
distances
=
.
789
),
ift
.
RGSpace
([
32
,
32
],
distances
=
.
789
)],
[
4
,
78
,
23
]))
def
testMul
(
self
,
space
,
seed
):
np
.
random
.
seed
(
seed
)
S
=
ift
.
ScalingOperator
(
1.
,
space
)
s1
=
S
.
draw_sample
()
s2
=
S
.
draw_sample
()
s1_var
=
ift
.
Variable
(
ift
.
MultiField
.
from_dict
({
's1'
:
s1
}))[
's1'
]
s2_var
=
ift
.
Variable
(
ift
.
MultiField
.
from_dict
({
's2'
:
s2
}))[
's2'
]
ift
.
extra
.
check_value_gradient_consistency
(
s1_var
*
s2_var
)
def
make_model
(
self
,
type
,
**
kwargs
):
if
type
==
'Constant'
:
np
.
random
.
seed
(
kwargs
[
'seed'
])
S
=
ift
.
ScalingOperator
(
1.
,
kwargs
[
'space'
])
s
=
S
.
draw_sample
()
return
ift
.
Constant
(
ift
.
MultiField
.
from_dict
({
kwargs
[
'space_key'
]:
s
}),
ift
.
MultiField
.
from_dict
({
kwargs
[
'space_key'
]:
s
}))
elif
type
==
'Variable'
:
np
.
random
.
seed
(
kwargs
[
'seed'
])
S
=
ift
.
ScalingOperator
(
1.
,
kwargs
[
'space'
])
s
=
S
.
draw_sample
()
return
ift
.
Variable
(
ift
.
MultiField
.
from_dict
({
kwargs
[
'space_key'
]:
s
}))
elif
type
==
'LinearModel'
:
return
ift
.
LinearModel
(
inp
=
kwargs
[
'model'
],
lin_op
=
kwargs
[
'lin_op'
])
else
:
raise
ValueError
(
'unknown type passed'
)
def
make_linear_operator
(
self
,
type
,
**
kwargs
):
if
type
==
'ScalingOperator'
:
lin_op
=
ift
.
ScalingOperator
(
1.
,
kwargs
[
'space'
])
else
:
raise
ValueError
(
'unknown type passed'
)
return
lin_op
@
expand
(
product
(
[
'Variable'
,
'Constant'
],
[
ift
.
GLSpace
(
15
),
ift
.
RGSpace
(
64
,
distances
=
.
789
),
ift
.
RGSpace
([
32
,
32
],
distances
=
.
789
)],
[
4
,
78
,
23
]
))
def
testBasics
(
self
,
type1
,
space
,
seed
):
model1
=
self
.
make_model
(
type1
,
space_key
=
's1'
,
space
=
space
,
seed
=
seed
)[
's1'
]
ift
.
extra
.
check_value_gradient_consistency
(
model1
)
@
expand
(
product
(
[
'Variable'
,
'Constant'
],
[
'Variable'
],
[
ift
.
GLSpace
(
15
),
ift
.
RGSpace
(
64
,
distances
=
.
789
),
ift
.
RGSpace
([
32
,
32
],
distances
=
.
789
)],
[
4
,
78
,
23
]
))
def
testMul
(
self
,
type1
,
type2
,
space
,
seed
):
model1
=
self
.
make_model
(
type1
,
space_key
=
's1'
,
space
=
space
,
seed
=
seed
)[
's1'
]
model2
=
self
.
make_model
(
type2
,
space_key
=
's2'
,
space
=
space
,
seed
=
seed
+
1
)[
's2'
]
ift
.
extra
.
check_value_gradient_consistency
(
model1
*
model2
)
@
expand
(
product
(
[
'Variable'
,
'Constant'
],
[
ift
.
GLSpace
(
15
),
ift
.
RGSpace
(
64
,
distances
=
.
789
),
ift
.
RGSpace
([
32
,
32
],
distances
=
.
789
)],
[
4
,
78
,
23
]
))
def
testLinModel
(
self
,
type1
,
space
,
seed
):
model1
=
self
.
make_model
(
type1
,
space_key
=
's1'
,
space
=
space
,
seed
=
seed
)[
's1'
]
lin_op
=
self
.
make_linear_operator
(
'ScalingOperator'
,
space
=
space
)
model2
=
self
.
make_model
(
'LinearModel'
,
model
=
model1
,
lin_op
=
lin_op
)
ift
.
extra
.
check_value_gradient_consistency
(
model1
*
model2
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment