def _to_numpy(arg):
    """Recursively replace every JAX array inside `arg` by a NumPy array.

    Dictionaries, tuples (including namedtuples) and lists are traversed;
    any other object is returned unchanged.

    Parameters
    ----------
    arg :
        Arbitrary (possibly nested) Python object.

    Returns
    -------
    A structure of the same layout as `arg` in which each `jax.Array` has
    been converted with `np.asarray`.
    """
    # NOTE, assume no cycles in the input and recurse into it without safeguards
    if isinstance(arg, jax.Array):
        return np.asarray(arg)
    if isinstance(arg, dict):
        # JAX arrays are not hashable so keys do not need to be checked
        return {k: _to_numpy(v) for k, v in arg.items()}
    if isinstance(arg, (tuple, list)):
        converted = [_to_numpy(a) for a in arg]
        if isinstance(arg, tuple) and hasattr(arg, "_fields"):
            # Namedtuples take their fields as positional arguments, not as
            # a single iterable
            return type(arg)(*converted)
        # The converted container must be returned; otherwise positional
        # arguments (always a tuple) would silently stay JAX arrays
        return type(arg)(converted)
    return arg


def safeguard_arguments_against_accidental_calls_into_jax(func):
    """Safeguard against using JAX in callback, see !25861

    Wrap `func` such that all JAX arrays found in its positional and keyword
    arguments (also when nested in dicts, tuples or lists) are converted to
    NumPy arrays before `func` is called.

    Parameters
    ----------
    func : callable
        Function to wrap.

    Returns
    -------
    callable
        Wrapper with the same signature and return value as `func`.
    """

    @wraps(func)
    def safe_func(*args, **kwargs):
        return func(*_to_numpy(args), **_to_numpy(kwargs))

    return safe_func
import conjugate_gradient from .logger import logger -from .misc import conditional_call, conditional_raise, doc_from -from .tree_math import PyTreeString, assert_arithmetics, hide_strings +from .misc import ( + conditional_call, + conditional_raise, + doc_from, + safeguard_arguments_against_accidental_calls_into_jax, +) +from .tree_math import ( + PyTreeString, + assert_arithmetics, + hide_strings, + result_type, + size, + vdot, + where, +) from .tree_math import norm as jft_norm -from .tree_math import result_type, size, vdot, where class OptimizeResults(NamedTuple): @@ -106,6 +118,7 @@ def static_newton_cg(fun=None, x0=None, *args, **kwargs): return _static_newton_cg(fun, x0, *args, **kwargs).x +@safeguard_arguments_against_accidental_calls_into_jax def _ncg_pretty_print_it( name, i, @@ -584,6 +597,7 @@ def _trust_ncg( old_fval=old_fval, ) + @safeguard_arguments_against_accidental_calls_into_jax def pp(arg): i = arg["i"] msg = (