Commit 3f1212d2 authored by Philipp Arras's avatar Philipp Arras
Browse files

Remove ParametricGaussianKL

parent 06b9edc2
......@@ -521,40 +521,3 @@ def GeoMetricKL(mean, hamiltonian, n_samples, minimizer_samp, mirror_samples,
mean, hamiltonian = _reduce_by_keys(mean, hamiltonian, constants)
return _SampledKLEnergy(mean, hamiltonian, sampler.n_eff_samples, False,
comm, local_samples, nanisinf)
def ParametricGaussianKL(variational_parameters, hamiltonian,
variational_model, n_samples, mirror_samples, comm=None,
nanisinf=False):
"""Provide the sampled Kullback-Leibler divergence between a distribution
and a Parametric Gaussian.
FIXME
"""
if not isinstance(hamiltonian, StandardHamiltonian):
raise TypeError
if hamiltonian.domain is not variational_model.generator.target:
raise ValueError
if not isinstance(n_samples, int):
raise TypeError
if not isinstance(mirror_samples, bool):
raise TypeError
from ..sugar import full, from_random
full_model = hamiltonian(variational_model.generator) + variational_model.entropy
local_samples = []
sseq = random.spawn_sseq(n_samples)
# FIXME dirty trick, many multiplications with zero
DirtyMaskDict = full(variational_model.generator.domain, 0.0).to_dict()
DirtyMaskDict["latent"] = full(
variational_model.generator.domain["latent"], 1.0
)
DirtyMask = MultiField.from_dict(DirtyMaskDict)
for i in range(*_get_lo_hi(comm, n_samples)):
with random.Context(sseq[i]):
s = DirtyMask * from_random(variational_model.generator.domain)
local_samples.append(s)
local_samples = tuple(local_samples)
return _SampledKLEnergy(variational_parameters, full_model, n_samples,
mirror_samples, comm, local_samples, nanisinf)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment