diff --git a/src/re/conjugate_gradient.py b/src/re/conjugate_gradient.py
index ba6e322b52abea09e9be3e0136dcd675f9e9cfc1..b73f4d1fe6d50b0a032af0c19240ffb9d642c73a 100644
--- a/src/re/conjugate_gradient.py
+++ b/src/re/conjugate_gradient.py
@@ -8,10 +8,9 @@ import jax
 from jax import numpy as jnp
 
 from .logger import logger
-from .misc import doc_from
-from .tree_math import assert_arithmetics
+from .misc import doc_from, safeguard_arguments_against_accidental_calls_into_jax
+from .tree_math import assert_arithmetics, result_type, size, vdot, where, zeros_like
 from .tree_math import norm as jft_norm
-from .tree_math import result_type, size, vdot, where, zeros_like
 
 HessVP = Callable[[jnp.ndarray], jnp.ndarray]
 
@@ -51,6 +50,7 @@ def static_cg(mat, j, x0=None, *args, **kwargs):
     return cg_res.x, cg_res.info
 
 
+@safeguard_arguments_against_accidental_calls_into_jax
 def _cg_pretty_print_it(
     name,
     i,
@@ -520,6 +520,7 @@ def _cg_steihaug_subproblem(
     # and hessian
     soa = partial(second_order_approx, cur_val=cur_val, g=g, hessp_at_xk=hessp_at_xk)
 
+    @safeguard_arguments_against_accidental_calls_into_jax
     def pp(arg):
         msg = (
             "{name}: |∇|:{r_norm:.6e} ➽:{resnorm:.6e} ↗:{tr:.6e}"
diff --git a/src/re/misc.py b/src/re/misc.py
index 5aa01f43ba5e9f2fa5d8ae81ee6209eead1cbad0..5a8c3d6081a6d46071d1ec861559aee6e73b1b4a 100644
--- a/src/re/misc.py
+++ b/src/re/misc.py
@@ -5,6 +5,7 @@ from functools import wraps
 from typing import Any, Callable, Dict, Hashable, Mapping, ParamSpec, TypeVar
 
 import jax
+import numpy as np
 from jax import numpy as jnp
 from jax.tree_util import Partial
 
@@ -95,8 +96,6 @@ def interpolate(xmin=-7.0, xmax=7.0, N=14000) -> Callable:
     """
 
     def decorator(f):
-        from functools import wraps
-
         x = jnp.linspace(xmin, xmax, N)
         y = f(x)
 
@@ -109,6 +108,29 @@ def interpolate(xmin=-7.0, xmax=7.0, N=14000) -> Callable:
     return decorator
 
 
+def _to_numpy(arg):
+    # NOTE, assume no cycles in the input and recurse into it without safeguards
+    if isinstance(arg, jax.Array):
+        return np.asarray(arg)
+    elif isinstance(arg, dict):
+        # JAX arrays are not hashable so keys do not need to be checked
+        return {k: _to_numpy(v) for k, v in arg.items()}
+    elif isinstance(arg, (tuple, list)):
+        type(arg)(map(_to_numpy, arg))
+    return arg
+
+
+def safeguard_arguments_against_accidental_calls_into_jax(func):
+    """Safeguard against using JAX in callback, see !25861"""
+
+    @wraps(func)
+    def safe_func(*args, **kwargs):
+        return func(*_to_numpy(args), **_to_numpy(kwargs))
+
+    return safe_func
+
+
+@safeguard_arguments_against_accidental_calls_into_jax
 def _maybe_raise(condition, exception):
     if condition:
         raise exception()
@@ -137,6 +159,7 @@ def conditional_raise(condition: bool, exception):
     )
 
 
+@safeguard_arguments_against_accidental_calls_into_jax
 def _maybe_call(condition, fn, args, kwargs):
     if condition:
         fn(*args, **kwargs)
@@ -145,6 +168,12 @@ def _maybe_call(condition, fn, args, kwargs):
 def conditional_call(condition, fn, *args, **kwargs):
     """JAX JIT-safe call to `fn` if `condition` is True.
 
+    Warning:
+    ---------
+    The function `fn` may NOT dispatch ANY JAX code, including by comparing JAX
+    objects! To safeguard against easy to miss fallacies, all JAX arrays are
+    automatically converted to numpy arrays.
+
     Parameters:
     -----------
     condition: boolean
diff --git a/src/re/optimize.py b/src/re/optimize.py
index 70876fe84acd978cf6d4cc5c2d9fcf8cc6e47fa3..80a8079c006b205c54baa2574a1c41f6c69fbc8e 100644
--- a/src/re/optimize.py
+++ b/src/re/optimize.py
@@ -10,10 +10,22 @@ from jax.tree_util import Partial
 
 from . import conjugate_gradient
 from .logger import logger
-from .misc import conditional_call, conditional_raise, doc_from
-from .tree_math import PyTreeString, assert_arithmetics, hide_strings
+from .misc import (
+    conditional_call,
+    conditional_raise,
+    doc_from,
+    safeguard_arguments_against_accidental_calls_into_jax,
+)
+from .tree_math import (
+    PyTreeString,
+    assert_arithmetics,
+    hide_strings,
+    result_type,
+    size,
+    vdot,
+    where,
+)
 from .tree_math import norm as jft_norm
-from .tree_math import result_type, size, vdot, where
 
 
 class OptimizeResults(NamedTuple):
@@ -106,6 +118,7 @@ def static_newton_cg(fun=None, x0=None, *args, **kwargs):
     return _static_newton_cg(fun, x0, *args, **kwargs).x
 
 
+@safeguard_arguments_against_accidental_calls_into_jax
 def _ncg_pretty_print_it(
     name,
     i,
@@ -584,6 +597,7 @@ def _trust_ncg(
         old_fval=old_fval,
     )
 
+    @safeguard_arguments_against_accidental_calls_into_jax
     def pp(arg):
         i = arg["i"]
         msg = (