diff --git a/src/minimization/kl_energies.py b/src/minimization/kl_energies.py index 049c01d4d4750929eb8b809d301dd0b53225dfad..c46b975dedd8be7e5b4793d69253c666890d4434 100644 --- a/src/minimization/kl_energies.py +++ b/src/minimization/kl_energies.py @@ -199,6 +199,25 @@ class _SampledKLEnergy(Energy): yield -s +class _MetricGaussianSampler: + def __init__(self, position, H, n_samples, mirror_samples, napprox=0): + if not isinstance(H, StandardHamiltonian): + raise NotImplementedError + lin = Linearization.make_var(position.extract(H.domain), True) + self._met = H(lin).metric + if napprox >= 1: + self._met._approximation = makeOp(approximation2endo(met, napprox)) + self._n = int(n_samples) + + def draw_samples(self, comm): + local_samples = [] + sseq = random.spawn_sseq(self._n) + for i in range(*_get_lo_hi(comm, self._n)): + with random.Context(sseq[i]): + local_samples.append(self._met.draw_sample(from_inverse=True)) + return tuple(local_samples) + + class _GeoMetricSampler: def __init__(self, position, H, minimizer, start_from_lin, n_samples, mirror_samples, napprox=0, want_error=False): @@ -388,16 +407,10 @@ def MetricGaussianKL(mean, hamiltonian, n_samples, mirror_samples, constants=[], mirror_samples = bool(mirror_samples) _, ham_sampling = _reduce_by_keys(mean, hamiltonian, point_estimates) - lin = Linearization.make_var(mean.extract(ham_sampling.domain), True) - met = ham_sampling(lin).metric - if napprox >= 1: - met._approximation = makeOp(approximation2endo(met, napprox)) - local_samples = [] - sseq = random.spawn_sseq(n_samples) - for i in range(*_get_lo_hi(comm, n_samples)): - with random.Context(sseq[i]): - local_samples.append(met.draw_sample(from_inverse=True)) - local_samples = tuple(local_samples) + sampler = _MetricGaussianSampler(mean, ham_sampling, n_samples, + mirror_samples) + local_samples = sampler.draw_samples(comm) + mean, hamiltonian = _reduce_by_keys(mean, hamiltonian, constants) return _SampledKLEnergy(mean, hamiltonian, n_samples, mirror_samples, comm, local_samples, nanisinf)