Commit 4d60d6c4 authored by Philipp Arras's avatar Philipp Arras
Browse files

fixup! Port to geoVI compatible NIFTy_7

parent ada4010a
......@@ -79,8 +79,8 @@ if __name__ == "__main__":
position_fc = fullcov_model.get_initial_pos(initial_sig=0.01)
position_mf = meanfield_model.get_initial_pos(initial_sig=0.01)
KL_fc = ift.ParametricGaussianKL.make(position_fc, H, fullcov_model, 3, True)
KL_mf = ift.ParametricGaussianKL.make(position_mf, H, meanfield_model, 3, True)
KL_fc = ift.ParametricGaussianKL(position_fc, H, fullcov_model, 3, True)
KL_mf = ift.ParametricGaussianKL(position_mf, H, meanfield_model, 3, True)
plt.pause(0.001)
for i in range(25):
......
......@@ -522,8 +522,9 @@ def GeoMetricKL(mean, hamiltonian, n_samples, minimizer_samp, mirror_samples,
return _SampledKLEnergy(mean, hamiltonian, sampler.n_eff_samples, False,
comm, local_samples, nanisinf)
def ParametricGaussianKL(variational_parameters, hamiltonian, variational_model, n_samples, mirror_samples, comm=None, nanisinf=False):
def ParametricGaussianKL(variational_parameters, hamiltonian,
variational_model, n_samples, mirror_samples, comm=None,
nanisinf=False):
"""Provide the sampled Kullback-Leibler divergence between a distribution
and a Parametric Gaussian.
......@@ -538,6 +539,7 @@ def GeoMetricKL(mean, hamiltonian, n_samples, minimizer_samp, mirror_samples,
if not isinstance(mirror_samples, bool):
raise TypeError
from ..sugar import full, from_random
assert variational_model.generator.target is hamiltonian.domain
# FIXME self._variational_model = variational_model
# FIXME self._full_model = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment