Commit 5f245882 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add docu to KL

parent eb60c29a
...@@ -58,6 +58,9 @@ class MetricGaussianKL(Energy): ...@@ -58,6 +58,9 @@ class MetricGaussianKL(Energy):
as they are equally legitimate samples. If true, the number of used as they are equally legitimate samples. If true, the number of used
samples doubles. Mirroring samples stabilizes the KL estimate as samples doubles. Mirroring samples stabilizes the KL estimate as
extreme sample variation is counterbalanced. Default is False. extreme sample variation is counterbalanced. Default is False.
napprox : int
Number of samples for computing preconditioner for sampling. No
preconditioning is done by default.
_samples : None _samples : None
Only a parameter for internal uses. Typically not to be set by users. Only a parameter for internal uses. Typically not to be set by users.
...@@ -74,7 +77,7 @@ class MetricGaussianKL(Energy): ...@@ -74,7 +77,7 @@ class MetricGaussianKL(Energy):
def __init__(self, mean, hamiltonian, n_samples, constants=[], def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False, point_estimates=[], mirror_samples=False,
_samples=None, napprox=0): napprox=0, _samples=None):
super(MetricGaussianKL, self).__init__(mean) super(MetricGaussianKL, self).__init__(mean)
if not isinstance(hamiltonian, StandardHamiltonian): if not isinstance(hamiltonian, StandardHamiltonian):
...@@ -94,9 +97,7 @@ class MetricGaussianKL(Energy): ...@@ -94,9 +97,7 @@ class MetricGaussianKL(Energy):
met = hamiltonian(Linearization.make_partial_var( met = hamiltonian(Linearization.make_partial_var(
mean, point_estimates, True)).metric mean, point_estimates, True)).metric
if napprox > 1: if napprox > 1:
print('Calculate preconditioner for sampling')
met._approximation = makeOp(approximation2endo(met, napprox)) met._approximation = makeOp(approximation2endo(met, napprox))
print('Done')
_samples = tuple(met.draw_sample(from_inverse=True) _samples = tuple(met.draw_sample(from_inverse=True)
for _ in range(n_samples)) for _ in range(n_samples))
if mirror_samples: if mirror_samples:
...@@ -121,7 +122,7 @@ class MetricGaussianKL(Energy): ...@@ -121,7 +122,7 @@ class MetricGaussianKL(Energy):
def at(self, position): def at(self, position):
return MetricGaussianKL(position, self._hamiltonian, 0, return MetricGaussianKL(position, self._hamiltonian, 0,
self._constants, self._point_estimates, self._constants, self._point_estimates,
_samples=self._samples, napprox=self._napprox) napprox=self._napprox, _samples=self._samples)
@property @property
def value(self): def value(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment