diff --git a/src/minimization/kl_energies.py b/src/minimization/kl_energies.py
index 049c01d4d4750929eb8b809d301dd0b53225dfad..c46b975dedd8be7e5b4793d69253c666890d4434 100644
--- a/src/minimization/kl_energies.py
+++ b/src/minimization/kl_energies.py
@@ -199,6 +199,25 @@ class _SampledKLEnergy(Energy):
                         yield -s
 
 
+class _MetricGaussianSampler:
+    def __init__(self, position, H, n_samples, mirror_samples, napprox=0):
+        if not isinstance(H, StandardHamiltonian):
+            raise NotImplementedError
+        lin = Linearization.make_var(position.extract(H.domain), True)
+        self._met = H(lin).metric
+        if napprox >= 1:
+            self._met._approximation = makeOp(approximation2endo(met, napprox))
+        self._n = int(n_samples)
+
+    def draw_samples(self, comm):
+        local_samples = []
+        sseq = random.spawn_sseq(self._n)
+        for i in range(*_get_lo_hi(comm, self._n)):
+            with random.Context(sseq[i]):
+                local_samples.append(self._met.draw_sample(from_inverse=True))
+        return tuple(local_samples)
+
+
 class _GeoMetricSampler:
     def __init__(self, position, H, minimizer, start_from_lin,
                  n_samples, mirror_samples, napprox=0, want_error=False):
@@ -388,16 +407,10 @@ def MetricGaussianKL(mean, hamiltonian, n_samples, mirror_samples, constants=[],
     mirror_samples = bool(mirror_samples)
 
     _, ham_sampling = _reduce_by_keys(mean, hamiltonian, point_estimates)
-    lin = Linearization.make_var(mean.extract(ham_sampling.domain), True)
-    met = ham_sampling(lin).metric
-    if napprox >= 1:
-        met._approximation = makeOp(approximation2endo(met, napprox))
-    local_samples = []
-    sseq = random.spawn_sseq(n_samples)
-    for i in range(*_get_lo_hi(comm, n_samples)):
-        with random.Context(sseq[i]):
-            local_samples.append(met.draw_sample(from_inverse=True))
-    local_samples = tuple(local_samples)
+    sampler = _MetricGaussianSampler(mean, ham_sampling, n_samples,
+                                     mirror_samples)
+    local_samples = sampler.draw_samples(comm)
+
     mean, hamiltonian = _reduce_by_keys(mean, hamiltonian, constants)
     return _SampledKLEnergy(mean, hamiltonian, n_samples, mirror_samples, comm,
                             local_samples, nanisinf)