Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
bb1069ae
Commit
bb1069ae
authored
Jan 21, 2020
by
Philipp Arras
Browse files
Expand lh_sampling_dtype to MPI-KL, fixup
parent
96d52905
Pipeline
#67485
passed with stages
in 25 minutes and 34 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/minimization/metric_gaussian_kl.py
View file @
bb1069ae
...
...
@@ -122,11 +122,13 @@ class MetricGaussianKL(Energy):
self
.
_grad
=
g
*
(
1.
/
len
(
self
.
_samples
))
self
.
_metric
=
None
self
.
_napprox
=
napprox
self
.
_sampdt
=
lh_sampling_dtype
def
at
(
self
,
position
):
return
MetricGaussianKL
(
position
,
self
.
_hamiltonian
,
0
,
self
.
_constants
,
self
.
_point_estimates
,
napprox
=
self
.
_napprox
,
_samples
=
self
.
_samples
)
napprox
=
self
.
_napprox
,
_samples
=
self
.
_samples
,
lh_sampling_dtype
=
self
.
_sampdt
)
@
property
def
value
(
self
):
...
...
nifty6/minimization/metric_gaussian_kl_mpi.py
View file @
bb1069ae
...
...
@@ -132,7 +132,8 @@ class MetricGaussianKL_MPI(Energy):
def
__init__
(
self
,
mean
,
hamiltonian
,
n_samples
,
constants
=
[],
point_estimates
=
[],
mirror_samples
=
False
,
napprox
=
0
,
_samples
=
None
,
seed_offset
=
0
):
napprox
=
0
,
_samples
=
None
,
seed_offset
=
0
,
lh_sampling_dtype
=
np
.
float64
):
super
(
MetricGaussianKL_MPI
,
self
).
__init__
(
mean
)
if
not
isinstance
(
hamiltonian
,
StandardHamiltonian
):
...
...
@@ -167,10 +168,12 @@ class MetricGaussianKL_MPI(Energy):
else
:
_samples
.
append
(((
i
%
2
)
*
2
-
1
)
*
met
.
draw_sample
(
from_inverse
=
True
))
met
.
draw_sample
(
from_inverse
=
True
,
dtype
=
lh_sampling_dtype
))
else
:
np
.
random
.
seed
(
i
+
seed_offset
)
_samples
.
append
(
met
.
draw_sample
(
from_inverse
=
True
))
_samples
.
append
(
met
.
draw_sample
(
from_inverse
=
True
,
dtype
=
lh_sampling_dtype
))
np
.
random
.
set_state
(
rand_state
)
_samples
=
tuple
(
_samples
)
if
mirror_samples
:
...
...
@@ -196,12 +199,13 @@ class MetricGaussianKL_MPI(Energy):
self
.
_val
=
np_allreduce_sum
(
v
)[()]
/
self
.
_n_samples
self
.
_grad
=
allreduce_sum_field
(
g
)
/
self
.
_n_samples
self
.
_metric
=
None
self
.
_sampdt
=
lh_sampling_dtype
def
at
(
self
,
position
):
return
MetricGaussianKL_MPI
(
position
,
self
.
_hamiltonian
,
self
.
_n_samples
,
self
.
_constants
,
self
.
_point_estimates
,
_samples
=
self
.
_samples
,
seed_offset
=
self
.
_seed_offset
)
seed_offset
=
self
.
_seed_offset
,
lh_sampling_dtype
=
self
.
_sampdt
)
@
property
def
value
(
self
):
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment