Commit eb60c29a by Philipp Arras

### Preconditioner for sampling

parent 5624dc3e
Pipeline #52349 passed with stages
in 8 minutes and 32 seconds
 ... ... @@ -18,6 +18,8 @@ from .. import utilities from ..linearization import Linearization from ..operators.energy_operators import StandardHamiltonian from ..probing import approximation2endo from ..sugar import makeOp from .energy import Energy ... ... @@ -72,7 +74,7 @@ class MetricGaussianKL(Energy): def __init__(self, mean, hamiltonian, n_samples, constants=[], point_estimates=[], mirror_samples=False, _samples=None): _samples=None, napprox=0): super(MetricGaussianKL, self).__init__(mean) if not isinstance(hamiltonian, StandardHamiltonian): ... ... @@ -91,6 +93,10 @@ class MetricGaussianKL(Energy): if _samples is None: met = hamiltonian(Linearization.make_partial_var( mean, point_estimates, True)).metric if napprox > 1: print('Calculate preconditioner for sampling') met._approximation = makeOp(approximation2endo(met, napprox)) print('Done') _samples = tuple(met.draw_sample(from_inverse=True) for _ in range(n_samples)) if mirror_samples: ... ... @@ -110,11 +116,12 @@ class MetricGaussianKL(Energy): self._val = v / len(self._samples) self._grad = g * (1./len(self._samples)) self._metric = None self._napprox = napprox def at(self, position): return MetricGaussianKL(position, self._hamiltonian, 0, self._constants, self._point_estimates, _samples=self._samples) _samples=self._samples, napprox=self._napprox) @property def value(self): ... ...
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment