Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
90bcde2c
Commit
90bcde2c
authored
Jun 19, 2020
by
Reimar Leike
Browse files
Add test for VariableCovarianceGaussianEnergy
parent
06d728fc
Pipeline
#76989
passed with stages
in 12 minutes and 37 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
test/test_operators/test_fisher_metric.py
View file @
90bcde2c
...
...
@@ -29,6 +29,8 @@ pmp = pytest.mark.parametrize
field
=
list2fixture
([
ift
.
from_random
(
sp
,
'normal'
)
for
sp
in
spaces
]
+
[
ift
.
from_random
(
sp
,
'normal'
,
dtype
=
np
.
complex128
)
for
sp
in
spaces
])
dtype
=
list2fixture
([
np
.
float64
,
np
.
complex128
])
Nsamp
=
2000
np
.
random
.
seed
(
42
)
...
...
@@ -110,7 +112,7 @@ def test_GaussianEnergy(field):
icov
=
ift
.
makeOp
(
icov
)
get_noisy_data
=
lambda
mean
:
mean
+
icov
.
draw_sample_with_dtype
(
from_inverse
=
True
,
dtype
=
dtype
)
E_init
=
lambda
mean
:
ift
.
GaussianEnergy
(
mean
=
mean
,
inverse_covariance
=
icov
)
E_init
=
lambda
data
:
ift
.
GaussianEnergy
(
mean
=
data
,
inverse_covariance
=
icov
)
energy_tester
(
field
,
get_noisy_data
,
E_init
)
...
...
@@ -122,5 +124,20 @@ def test_PoissonEnergy(field):
get_noisy_data
=
lambda
mean
:
ift
.
makeField
(
mean
.
domain
,
np
.
random
.
poisson
(
mean
.
val
))
# Make rate positive and high enough to avoid bad statistic
lam
=
10
*
(
field
**
2
).
clip
(
0.1
,
None
)
E_init
=
lambda
mean
:
ift
.
PoissonianEnergy
(
mean
)
E_init
=
lambda
data
:
ift
.
PoissonianEnergy
(
data
)
energy_tester
(
lam
,
get_noisy_data
,
E_init
)
def
test_VariableCovarianceGaussianEnergy
(
dtype
):
dom
=
ift
.
UnstructuredDomain
(
3
)
res
=
ift
.
from_random
(
dom
,
'normal'
,
dtype
=
dtype
)
ivar
=
ift
.
from_random
(
dom
,
'normal'
)
**
2
+
4.
mf
=
ift
.
MultiField
.
from_dict
({
'res'
:
res
,
'ivar'
:
ivar
})
energy
=
ift
.
VariableCovarianceGaussianEnergy
(
dom
,
'res'
,
'ivar'
,
dtype
)
def
get_noisy_data
(
mean
):
samp
=
ift
.
from_random
(
dom
,
'normal'
,
dtype
)
samp
=
samp
/
mean
[
'ivar'
].
sqrt
()
return
samp
+
mean
[
'res'
]
def
E_init
(
data
):
adder
=
ift
.
Adder
(
ift
.
MultiField
.
from_dict
({
'res'
:
data
}),
neg
=
True
)
return
energy
.
partial_insert
(
adder
)
energy_tester
(
mf
,
get_noisy_data
,
E_init
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment