diff --git a/src/re/hmc_oo.py b/src/re/hmc_oo.py
index 84d80a41f2f2d0bbfd90e6615ac74dbfcd006e8d..4e71b54d56f971ac41912dd78d1e68ea190d4da7 100644
--- a/src/re/hmc_oo.py
+++ b/src/re/hmc_oo.py
@@ -17,6 +17,7 @@ from .hmc import (
     sample_momentum_from_diagonal,
     tree_index_update,
 )
+from .tree_math import vdot
 
 
 def _parse_diag_mass_matrix(mass_matrix, position_proto: Q) -> Q:
@@ -77,7 +78,7 @@ class _Sampler:
 
         def kinetic_energy(inverse_mass_matrix, momentum):
             # NOTE, assume a diagonal mass-matrix
-            return inverse_mass_matrix.dot(momentum**2) / 2.0
+            return vdot(inverse_mass_matrix, momentum**2 / 2.0)
 
         self.kinetic_energy = kinetic_energy
         kinetic_energy_gradient = lambda inv_m, mom: inv_m * mom