diff --git a/src/re/hmc.py b/src/re/hmc.py index fd7b00dd42319f925ecededfd84c672734419cbd..06f850d48a934a4ad8e18d13d8c9256361fe992f 100644 --- a/src/re/hmc.py +++ b/src/re/hmc.py @@ -10,7 +10,7 @@ from jax.experimental import io_callback from jax.scipy.special import expit from .lax import cond, fori_loop, while_loop -from .tree_math import random_like +from .tree_math import random_like, vdot _DEBUG_FLAG = False @@ -405,8 +405,8 @@ def is_euclidean_uturn(qp_left, qp_right): -------- Betancourt - A conceptual introduction to Hamiltonian Monte Carlo """ - return (qp_right.momentum.dot(qp_right.position - qp_left.position) < 0.0) & ( - qp_left.momentum.dot(qp_left.position - qp_right.position) < 0.0 + return (vdot(qp_right.momentum, qp_right.position - qp_left.position) < 0.0) & ( + vdot(qp_left.momentum, qp_left.position - qp_right.position) < 0.0 )