Commit 3f1212d2 authored by Philipp Arras's avatar Philipp Arras
Browse files

Remove ParametricGaussianKL

parent 06b9edc2
Pipeline #103249 passed with stages
in 15 minutes and 15 seconds
......@@ -521,40 +521,3 @@ def GeoMetricKL(mean, hamiltonian, n_samples, minimizer_samp, mirror_samples,
mean, hamiltonian = _reduce_by_keys(mean, hamiltonian, constants)
return _SampledKLEnergy(mean, hamiltonian, sampler.n_eff_samples, False,
comm, local_samples, nanisinf)
def ParametricGaussianKL(variational_parameters, hamiltonian,
variational_model, n_samples, mirror_samples, comm=None,
"""Provide the sampled Kullback-Leibler divergence between a distribution
and a Parametric Gaussian.
if not isinstance(hamiltonian, StandardHamiltonian):
raise TypeError
if hamiltonian.domain is not
raise ValueError
if not isinstance(n_samples, int):
raise TypeError
if not isinstance(mirror_samples, bool):
raise TypeError
from ..sugar import full, from_random
full_model = hamiltonian(variational_model.generator) + variational_model.entropy
local_samples = []
sseq = random.spawn_sseq(n_samples)
# FIXME dirty trick, many multiplications with zero
DirtyMaskDict = full(variational_model.generator.domain, 0.0).to_dict()
DirtyMaskDict["latent"] = full(
variational_model.generator.domain["latent"], 1.0
DirtyMask = MultiField.from_dict(DirtyMaskDict)
for i in range(*_get_lo_hi(comm, n_samples)):
with random.Context(sseq[i]):
s = DirtyMask * from_random(variational_model.generator.domain)
local_samples = tuple(local_samples)
return _SampledKLEnergy(variational_parameters, full_model, n_samples,
mirror_samples, comm, local_samples, nanisinf)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment