Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
d799dd37
Commit
d799dd37
authored
Jun 17, 2020
by
Reimar Leike
Browse files
Merge branch 'fisher_tests' into fix_complex_samples
parents
506ba68d
b703ab22
Changes
1
Hide whitespace changes
Inline
Side-by-side
test/test_operators/test_fisher_metric.py
0 → 100644
View file @
d799dd37
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import
numpy
as
np
import
pytest
import
nifty6
as
ift
from
..common
import
list2fixture
,
setup_function
,
teardown_function
spaces
=
[
ift
.
GLSpace
(
5
),
ift
.
MultiDomain
.
make
({
''
:
ift
.
RGSpace
(
5
,
distances
=
.
789
)}),
(
ift
.
RGSpace
(
3
,
distances
=
.
789
),
ift
.
UnstructuredDomain
(
2
))]
pmp
=
pytest
.
mark
.
parametrize
field
=
list2fixture
([
ift
.
from_random
(
sp
,
'normal'
)
for
sp
in
spaces
]
+
[
ift
.
from_random
(
sp
,
'normal'
,
dtype
=
np
.
complex128
)
for
sp
in
spaces
])
Nsamp
=
2000
np
.
random
.
seed
(
42
)
def
_to_array
(
d
):
if
isinstance
(
d
,
np
.
ndarray
):
return
d
assert
isinstance
(
d
,
dict
)
return
np
.
concatenate
(
list
(
d
.
values
()))
def
energy_tester
(
pos
,
get_noisy_data
,
energy_initializer
):
domain
=
pos
.
domain
test_vec
=
ift
.
from_random
(
domain
,
'normal'
)
results
=
[]
lin
=
ift
.
Linearization
.
make_var
(
pos
)
for
i
in
range
(
Nsamp
):
data
=
get_noisy_data
(
pos
)
energy
=
energy_initializer
(
data
)
grad
=
energy
(
lin
).
jac
.
adjoint
(
ift
.
full
(
energy
.
target
,
1.
))
results
.
append
(
_to_array
((
grad
*
grad
.
s_vdot
(
test_vec
)).
val
))
res
=
np
.
mean
(
np
.
array
(
results
),
axis
=
0
)
std
=
np
.
std
(
np
.
array
(
results
),
axis
=
0
)
/
np
.
sqrt
(
Nsamp
)
energy
=
energy_initializer
(
data
)
lin
=
ift
.
Linearization
.
make_var
(
pos
,
want_metric
=
True
)
res2
=
_to_array
(
energy
(
lin
).
metric
(
test_vec
).
val
)
np
.
testing
.
assert_allclose
(
res
/
std
,
res2
/
std
,
atol
=
6
)
def
test_GaussianEnergy
(
field
):
dtype
=
field
.
dtype
icov
=
ift
.
from_random
(
field
.
domain
,
'normal'
)
**
2
icov
=
ift
.
makeOp
(
icov
)
get_noisy_data
=
lambda
mean
:
mean
+
icov
.
draw_sample_with_dtype
(
from_inverse
=
True
,
dtype
=
dtype
)
E_init
=
lambda
mean
:
ift
.
GaussianEnergy
(
mean
=
mean
,
inverse_covariance
=
icov
)
energy_tester
(
field
,
get_noisy_data
,
E_init
)
def
test_PoissonEnergy
(
field
):
if
not
isinstance
(
field
,
ift
.
Field
):
return
if
np
.
iscomplexobj
(
field
.
val
):
return
def
get_noisy_data
(
mean
):
return
ift
.
makeField
(
mean
.
domain
,
np
.
random
.
poisson
(
mean
.
val
))
lam
=
10
*
(
field
**
2
).
clip
(
0.1
,
None
)
# make rate positive and high enough to avoid bad statistic
E_init
=
lambda
mean
:
ift
.
PoissonianEnergy
(
mean
)
energy_tester
(
lam
,
get_noisy_data
,
E_init
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment