diff --git a/src/re/hmc_oo.py b/src/re/hmc_oo.py index 84d80a41f2f2d0bbfd90e6615ac74dbfcd006e8d..4e71b54d56f971ac41912dd78d1e68ea190d4da7 100644 --- a/src/re/hmc_oo.py +++ b/src/re/hmc_oo.py @@ -17,6 +17,7 @@ from .hmc import ( sample_momentum_from_diagonal, tree_index_update, ) +from .tree_math import vdot def _parse_diag_mass_matrix(mass_matrix, position_proto: Q) -> Q: @@ -77,7 +78,7 @@ class _Sampler: def kinetic_energy(inverse_mass_matrix, momentum): # NOTE, assume a diagonal mass-matrix - return inverse_mass_matrix.dot(momentum**2) / 2.0 + return vdot(inverse_mass_matrix, momentum**2 / 2.0) self.kinetic_energy = kinetic_energy kinetic_energy_gradient = lambda inv_m, mom: inv_m * mom