diff --git a/ChangeLog.md b/ChangeLog.md
index 6f7b699d15e2301d15e41726e1b46890bbbf61fe..fb887b1a0b40c3826fe9372b42ad396f88e13d58 100644
--- a/ChangeLog.md
+++ b/ChangeLog.md
@@ -1,6 +1,12 @@
 Changes since NIFTy 8
 =====================
 
+`estimate_evidence_lower_bound`
+-------------------------------
+
+Renamed the `batch_number` parameter (`batch_size` in the JAX implementation
+under `src/re`) to `n_batches` for clarity and consistency; it is now
+keyword-only. Fixed the batching logic so that the batch sizes always sum to
+`n_eigenvalues`, and added an optional `dtype` parameter to the non-JAX variant.
+
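+A minimal migration sketch (the Hamiltonian, samples, and eigenvalue count are
+placeholders):
+
+    # before
+    elbo, stats = estimate_evidence_lower_bound(ham, samples, 4, batch_number=10)
+    # after
+    elbo, stats = estimate_evidence_lower_bound(ham, samples, 4, n_batches=10)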
+
 Minimum Python version increased to 3.10
 
 Stabilize ICR at the cost of disallowing using old ICR reconstructions
diff --git a/src/evidence_lower_bound.py b/src/evidence_lower_bound.py
index 41e235142966764ddb64cc2264951538fce3ce8a..58e52ed30b81c1043e858d7d518537c779587a72 100644
--- a/src/evidence_lower_bound.py
+++ b/src/evidence_lower_bound.py
@@ -42,14 +42,15 @@ class _Projector(ssl.LinearOperator):
     Projector : LinearOperator
         Operator representing the projection.
     """
+
     def __init__(self, eigenvectors):
-        super().__init__(np.dtype('f8'), 2 * (eigenvectors.shape[0],))
+        super().__init__(eigenvectors.dtype, 2 * (eigenvectors.shape[0],))
         self.eigenvectors = eigenvectors
 
     def _matvec(self, x):
         res = x.copy()
         for eigenvector in self.eigenvectors.T:
-            res -= eigenvector * eigenvector.dot(x)
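+            # np.vdot conjugates its first argument, so the projection is also
+            # correct for complex eigenvectors.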
+            res -= eigenvector * np.vdot(eigenvector, x)
         return res
 
     def _rmatvec(self, x):
@@ -61,40 +62,76 @@ def _explicify(M):
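+    # Materialize the implicit linear operator as a dense matrix, one basis
+    # vector at a time.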
     m = []
     for v in identity:
         m.append(M.matvec(v))
-    return np.vstack(m).T
-
-
-def _eigsh(metric, n_eigenvalues, tot_dofs, min_lh_eval=1e-4, batch_number=10, tol=0., verbose=True):
+    return np.column_stack(m)
+
+
+def _eigsh(
+    metric,
+    n_eigenvalues,
+    tot_dofs,
+    dtype,
+    min_lh_eval=1e-4,
+    n_batches=10,
+    tol=0.0,
+    verbose=True,
+):
     metric = SandwichOperator.make(_DomRemover(metric.domain).adjoint, metric)
     metric_size = metric.domain.size
-    M = ssl.LinearOperator(shape=2 * (metric_size,), matvec=lambda x: metric(makeField(metric.domain, x)).val)
+    M = ssl.LinearOperator(
+        shape=2 * (metric_size,),
+        matvec=lambda x: metric(makeField(metric.domain, x)).val,
+        dtype=dtype,
+    )
     eigenvectors = None
 
     if n_eigenvalues > tot_dofs:
-        raise ValueError("Number of requested eigenvalues exceeds the number of relevant degrees of freedom!")
+        raise ValueError(
+            "Number of requested eigenvalues exceeds the number of relevant degrees of freedom!"
+        )
 
     if tot_dofs == n_eigenvalues:
         # Compute exact eigensystem
         if verbose:
             logger.info(f"Computing all {tot_dofs} relevant metric eigenvalues.")
-        eigenvalues = slg.eigh(_explicify(M), eigvals_only=True,
-                               subset_by_index=[metric_size - tot_dofs, metric_size - 1])
+        eigenvalues = slg.eigh(
+            _explicify(M),
+            eigvals_only=True,
+            subset_by_index=[metric_size - tot_dofs, metric_size - 1],
+        )
         eigenvalues = np.flip(eigenvalues)
     else:
         # Set up batches
-        batch_size = n_eigenvalues // batch_number
-        batches = [batch_size, ] * (batch_number - 1)
-        batches += [n_eigenvalues - batch_size * (batch_number - 1), ]
+        batch_size = n_eigenvalues // n_batches
+        remainder = n_eigenvalues % n_batches
+        batches = [batch_size] * n_batches
+        batches += [remainder] if remainder > 0 else []
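+        # Invariant: sum(batches) == n_eigenvalues; a final short batch absorbs
+        # the division remainder.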
         eigenvalues, projected_metric = None, M
+
         for batch in batches:
             if verbose:
                 logger.info(f"\nNumber of eigenvalues being computed: {batch}")
             # Get eigensystem for current batch
-            eigvals, eigvecs = ssl.eigsh(projected_metric, k=batch, tol=tol, return_eigenvectors=True, which='LM')
+            eigvals, eigvecs = ssl.eigsh(
+                projected_metric, k=batch, tol=tol, return_eigenvectors=True, which="LM"
+            )
             i = np.argsort(eigvals)
             eigvals, eigvecs = np.flip(eigvals[i]), np.flip(eigvecs[:, i], axis=1)
-            eigenvalues = eigvals if eigenvalues is None else np.concatenate((eigenvalues, eigvals))
-            eigenvectors = eigvecs if eigenvectors is None else np.hstack((eigenvectors, eigvecs))
+            eigenvalues = (
+                eigvals
+                if eigenvalues is None
+                else np.concatenate((eigenvalues, eigvals))
+            )
+            eigenvectors = (
+                eigvecs if eigenvectors is None else np.hstack((eigenvectors, eigvecs))
+            )
 
             if abs(1.0 - np.min(eigenvalues)) < min_lh_eval:
                 break
@@ -104,8 +141,17 @@ def _eigsh(metric, n_eigenvalues, tot_dofs, min_lh_eval=1e-4, batch_number=10, t
     return eigenvalues, eigenvectors
 
 
-def estimate_evidence_lower_bound(hamiltonian, samples, n_eigenvalues, min_lh_eval=1e-3, batch_number=10, tol=0.,
-                                  verbose=True):
+def estimate_evidence_lower_bound(
+    hamiltonian,
+    samples,
+    n_eigenvalues,
+    *,
+    min_lh_eval=1e-3,
+    n_batches=10,
+    tol=0.0,
+    verbose=True,
+    dtype=np.float64,
+):
     """Provides an estimate for the Evidence Lower Bound (ELBO).
 
     Statistical inference deals with the problem of hypothesis testing, given
@@ -161,7 +207,7 @@ def estimate_evidence_lower_bound(hamiltonian, samples, n_eigenvalues, min_lh_ev
         eigenvalue estimation terminates and uses the smallest eigenvalue as a
         proxy for all remaining eigenvalues in the trace-log estimation.
         Default is 1e-3.
-    batch_number : int
-        Number of batches into which the eigenvalue estimation gets subdivided
-        into. Only after completing one batch the early stopping criterion
-        based on `min_lh_eval` is checked for.
+    n_batches : int
+        Number of batches into which the eigenvalue estimation is subdivided.
+        The early stopping criterion based on `min_lh_eval` is only checked
+        after each completed batch.
@@ -171,6 +217,8 @@ def estimate_evidence_lower_bound(hamiltonian, samples, n_eigenvalues, min_lh_ev
     verbose : Optional[bool]
         Print list of eigenvalues and summary of evidence calculation. Default
         is True.
+    dtype : Optional[numpy.dtype]
+        Data type of the eigenvalues and eigenvectors. Default is np.float64.
 
     Returns
     -------
@@ -218,36 +266,58 @@ def estimate_evidence_lower_bound(hamiltonian, samples, n_eigenvalues, min_lh_ev
     if not isinstance(hamiltonian, StandardHamiltonian):
         raise TypeError("hamiltonian is not an instance of `ift.StandardHamiltonian`.")
 
-    n_data_points = hamiltonian.likelihood_energy.data_domain.size if not None else hamiltonian.domain.size
-    n_relevant_dofs = min(n_data_points, hamiltonian.domain.size)  # Number of metric eigenvalues that are not one
+    data_domain = hamiltonian.likelihood_energy.data_domain
+    n_data_points = (
+        data_domain.size if data_domain is not None else hamiltonian.domain.size
+    )
+    n_relevant_dofs = min(
+        n_data_points, hamiltonian.domain.size
+    )  # Number of metric eigenvalues that are not one
 
     metric = hamiltonian(Linearization.make_var(samples._m, want_metric=True)).metric
     metric_size = metric.domain.size
-    eigenvalues, _ = _eigsh(metric, n_eigenvalues, tot_dofs=n_relevant_dofs, min_lh_eval=min_lh_eval,
-                            batch_number=batch_number, tol=tol, verbose=verbose)
+
+    eigenvalues, _ = _eigsh(
+        metric,
+        n_eigenvalues,
+        n_relevant_dofs,
+        dtype,
+        min_lh_eval=min_lh_eval,
+        n_batches=n_batches,
+        tol=tol,
+        verbose=verbose,
+    )
     if verbose:
         # FIXME
         logger.info(
             f"\nComputed {eigenvalues.size} largest eigenvalues (out of {n_relevant_dofs} relevant degrees of freedom)."
             f"\nThe remaining {metric_size - n_relevant_dofs} metric eigenvalues (out of {metric_size}) are equal to "
-            f"1.\n\n{eigenvalues}.")
+            f"1.\n\n{eigenvalues}."
+        )
 
     # Return a list of ELBO samples and a summary of the ELBO statistics
     log_eigenvalues = np.log(eigenvalues)
-    tr_log_lat_cov = - 0.5 * np.sum(log_eigenvalues)
-    tr_log_lat_cov_lower = 0.5 * (n_relevant_dofs - log_eigenvalues.size) * np.min(log_eigenvalues)
+    tr_log_lat_cov = -0.5 * np.sum(log_eigenvalues)
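+    # Bound the contribution of the uncomputed eigenvalues by using the smallest
+    # computed eigenvalue as a proxy for the remaining ones (cf. `min_lh_eval`).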
+    tr_log_lat_cov_lower = (
+        0.5 * (n_relevant_dofs - log_eigenvalues.size) * np.min(log_eigenvalues)
+    )
     tr_log_lat_cov_lower = Field.scalar(tr_log_lat_cov_lower)
     posterior_contribution = Field.scalar(tr_log_lat_cov + 0.5 * metric_size)
-    elbo_samples = SampleList(list(samples.iterator(lambda x: posterior_contribution - hamiltonian(x))))
+    elbo_samples = SampleList(
+        list(samples.iterator(lambda x: posterior_contribution - hamiltonian(x)))
+    )
 
-    stats = {'lower_error': tr_log_lat_cov_lower}
+    stats = {"lower_error": tr_log_lat_cov_lower}
     elbo_mean, elbo_var = elbo_samples.sample_stat()
     elbo_up = elbo_mean + elbo_var.sqrt()
     elbo_lw = elbo_mean - elbo_var.sqrt() - stats["lower_error"]
-    stats['elbo_mean'], stats['elbo_up'], stats['elbo_lw'] = elbo_mean, elbo_up, elbo_lw
+    stats["elbo_mean"], stats["elbo_up"], stats["elbo_lw"] = elbo_mean, elbo_up, elbo_lw
     if verbose:
-        s = (f"\nELBO decomposition (in log units)"
-             f"\nELBO mean : {elbo_mean.val:.4e} (upper: {elbo_up.val:.4e}, lower: {elbo_lw.val:.4e})")
+        s = (
+            f"\nELBO decomposition (in log units)"
+            f"\nELBO mean : {elbo_mean.val:.4e} (upper: {elbo_up.val:.4e}, lower: {elbo_lw.val:.4e})"
+        )
         logger.info(s)
 
     return elbo_samples, stats
diff --git a/src/re/evidence_lower_bound.py b/src/re/evidence_lower_bound.py
index 20c7bf11d33662dacf89bc6d4123041b7c9b6fcc..f35921ece1d00da44392ef4278470de35ebaa1d1 100644
--- a/src/re/evidence_lower_bound.py
+++ b/src/re/evidence_lower_bound.py
@@ -22,11 +22,6 @@ class _Projector(ssl.LinearOperator):
     ----------
     eigenvectors : ndarray
         The eigenvectors representing the directions to project out.
-
-    Returns
-    -------
-    Projector : LinearOperator
-        Operator representing the projection.
     """
 
     def __init__(self, eigenvectors):
@@ -44,35 +39,33 @@ class _Projector(ssl.LinearOperator):
 
 
 def _explicify(M):
-    n = M.shape[0]
-    m = []
-    for i in range(n):
-        basis_vector = np.zeros(n)
-        basis_vector[i] = 1
-        m.append(M @ basis_vector)
-    return np.stack(m, axis=1)
+    identity = np.identity(M.shape[0], dtype=np.float64)
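+    # Each column of the dense matrix is M applied to one canonical basis vector.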
+    return np.column_stack([M.matvec(v) for v in identity])
 
 
 def _ravel_metric(metric, position, dtype):
-    shape = 2 * (size(metric(position, position)),)
+    def ravel(x):
+        return jax.flatten_util.ravel_pytree(x)[0]
 
-    ravel = lambda x: jax.flatten_util.ravel_pytree(x)[0]
-    unravel = lambda x: jax.linear_transpose(ravel, position)(x)[0]
-    met = lambda x: ravel(metric(position, unravel(x)))
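+    # ravel_pytree returns the flattened array together with a function that
+    # restores the original pytree structure.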
+    flat_position, unravel = jax.flatten_util.ravel_pytree(position)
+    shape = 2 * (flat_position.size,)
+
+    def met(x):
+        return ravel(metric(position, unravel(x)))
 
     return ssl.LinearOperator(shape=shape, dtype=dtype, matvec=met)
 
 
 def _eigsh(
     metric,
+    metric_size,
     n_eigenvalues,
     tot_dofs,
     min_lh_eval=1e-4,
-    batch_size=10,
+    n_batches=10,
     tol=0.0,
     verbose=True,
 ):
-    metric_size = metric.shape[0]
     eigenvectors = None
     if n_eigenvalues > tot_dofs:
         raise ValueError(
@@ -83,7 +76,7 @@ def _eigsh(
     if tot_dofs == n_eigenvalues:
         # Compute exact eigensystem
         if verbose:
-            logger.info(f"Computing all {tot_dofs} relevant " f"metric eigenvalues.")
+            logger.info(f"Computing all {tot_dofs} relevant metric eigenvalues.")
         eigenvalues = slg.eigh(
             _explicify(metric),
             eigvals_only=True,
@@ -92,14 +85,12 @@ def _eigsh(
         eigenvalues = np.flip(eigenvalues)
     else:
         # Set up batches
-        batch_size = n_eigenvalues // batch_size
-        batches = [
-            batch_size,
-        ] * (batch_size - 1)
-        batches += [
-            n_eigenvalues - batch_size * (batch_size - 1),
-        ]
+        batch_size = n_eigenvalues // n_batches
+        remainder = n_eigenvalues % n_batches
+        batches = [batch_size] * n_batches
+        batches += [remainder] if remainder > 0 else []
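+        # Invariant: sum(batches) == n_eigenvalues, mirroring the non-JAX
+        # implementation.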
         eigenvalues, projected_metric = None, metric
+
         for batch in batches:
             if verbose:
                 logger.info(f"\nNumber of eigenvalues being computed: {batch}")
@@ -130,8 +121,9 @@ def estimate_evidence_lower_bound(
     likelihood,
     samples,
     n_eigenvalues,
+    *,
     min_lh_eval=1e-3,
-    batch_size=10,
+    n_batches=10,
     tol=0.0,
     verbose=True,
 ):
@@ -189,7 +181,7 @@ def estimate_evidence_lower_bound(
         eigenvalue estimation terminates and uses the smallest eigenvalue as a
         proxy for all remaining eigenvalues in the trace-log estimation.
         Default is 1e-3.
-    batch_size : int
-        Number of batches into which the eigenvalue estimation gets subdivided
-        into. Only after completing one batch the early stopping criterion
-        based on `min_lh_eval` is checked for.
+    n_batches : int
+        Number of batches into which the eigenvalue estimation is subdivided.
+        The early stopping criterion based on `min_lh_eval` is only checked
+        after each completed batch.
@@ -250,19 +242,18 @@ def estimate_evidence_lower_bound(
 
     hamiltonian = StandardHamiltonian(likelihood)
     metric = hamiltonian.metric
+    metric_size = jax.flatten_util.ravel_pytree(samples.pos)[0].size
     metric = _ravel_metric(metric, samples.pos, dtype=likelihood.target.dtype)
-    metric_size = metric.shape[0]
-    n_data_points = size(likelihood.lsm_tangents_shape) if not None else metric_size
+    lsm_shape = likelihood.lsm_tangents_shape
+    n_data_points = size(lsm_shape) if lsm_shape is not None else metric_size
-    n_relevant_dofs = min(
-        n_data_points, metric_size
-    )  # Number of metric eigenvalues that are not 1
+    n_relevant_dofs = min(n_data_points, metric_size)
 
     eigenvalues, _ = _eigsh(
         metric,
+        metric_size,
         n_eigenvalues,
         tot_dofs=n_relevant_dofs,
         min_lh_eval=min_lh_eval,
-        batch_size=batch_size,
+        n_batches=n_batches,
         tol=tol,
         verbose=verbose,
     )
diff --git a/test/test_re/test_estimate_evidence_lower_bound.py b/test/test_re/test_estimate_evidence_lower_bound.py
index 072f1992a53a19b3f8b7636d3715ca9fe0048b9b..03dba313a3737de50fa38bfb671eba34b6da64a0 100644
--- a/test/test_re/test_estimate_evidence_lower_bound.py
+++ b/test/test_re/test_estimate_evidence_lower_bound.py
@@ -24,10 +24,7 @@ def _explicify(M, position):
     unravel = lambda x: jax.linear_transpose(ravel, position)(x)[0]
     mat = lambda x: M(unravel(x))
     identity = np.identity(dim, dtype=np.float64)
-    m = []
-    for v in identity:
-        m.append(mat(v))
-    return np.vstack(m).T
+    return np.column_stack([mat(v) for v in identity])
 
 
 def get_linear_response(slope_op, intercept_op, sampling_points):
@@ -240,10 +237,8 @@ def test_estimate_elbo_nifty_re_vs_nifty(seed):
 
     n_ham = ift.StandardHamiltonian(n_like)
 
-    elbo, stats = jft.estimate_evidence_lower_bound(like, samples, 4, batch_size=2)
-    n_elbo, nstats = ift.estimate_evidence_lower_bound(
-        n_ham, n_samples, 4, batch_number=2
-    )
+    elbo, stats = jft.estimate_evidence_lower_bound(like, samples, 4, n_batches=2)
+    n_elbo, nstats = ift.estimate_evidence_lower_bound(n_ham, n_samples, 4, n_batches=2)
 
     n_elbo_samples = []
     for n_elbo_sample in n_elbo.iterator():