Commit e60207e1 authored by Philipp Arras's avatar Philipp Arras
Browse files

Unify MGVI and GeoVI implementation

parent 3de57d9c
......@@ -199,6 +199,25 @@ class _SampledKLEnergy(Energy):
yield -s
class _MetricGaussianSampler:
def __init__(self, position, H, n_samples, mirror_samples, napprox=0):
if not isinstance(H, StandardHamiltonian):
raise NotImplementedError
lin = Linearization.make_var(position.extract(H.domain), True)
self._met = H(lin).metric
if napprox >= 1:
self._met._approximation = makeOp(approximation2endo(met, napprox))
self._n = int(n_samples)
def draw_samples(self, comm):
local_samples = []
sseq = random.spawn_sseq(self._n)
for i in range(*_get_lo_hi(comm, self._n)):
with random.Context(sseq[i]):
return tuple(local_samples)
class _GeoMetricSampler:
def __init__(self, position, H, minimizer, start_from_lin,
n_samples, mirror_samples, napprox=0, want_error=False):
......@@ -388,16 +407,10 @@ def MetricGaussianKL(mean, hamiltonian, n_samples, mirror_samples, constants=[],
mirror_samples = bool(mirror_samples)
_, ham_sampling = _reduce_by_keys(mean, hamiltonian, point_estimates)
lin = Linearization.make_var(mean.extract(ham_sampling.domain), True)
met = ham_sampling(lin).metric
if napprox >= 1:
met._approximation = makeOp(approximation2endo(met, napprox))
local_samples = []
sseq = random.spawn_sseq(n_samples)
for i in range(*_get_lo_hi(comm, n_samples)):
with random.Context(sseq[i]):
local_samples = tuple(local_samples)
sampler = _MetricGaussianSampler(mean, ham_sampling, n_samples,
local_samples = sampler.draw_samples(comm)
mean, hamiltonian = _reduce_by_keys(mean, hamiltonian, constants)
return _SampledKLEnergy(mean, hamiltonian, n_samples, mirror_samples, comm,
local_samples, nanisinf)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment