Skip to content
Snippets Groups Projects
Commit 357cd238 authored by Matteo Guardiani's avatar Matteo Guardiani
Browse files

408-host_callback_deprecated: address review.

parent 4e99866a
No related branches found
No related tags found
1 merge request!937Resolve "host_callback_depricated"
Pipeline #207168 passed
......@@ -6,7 +6,7 @@ from typing import Callable, NamedTuple, TypeVar, Union
from jax import lax
from jax import numpy as jnp
from jax import random, tree_util
from jax.experimental import host_callback
from jax.debug import callback
from jax.scipy.special import expit
from .lax import cond, fori_loop, while_loop
......@@ -124,7 +124,7 @@ def leapfrog_step(
global _DEBUG_FLAG
if _DEBUG_FLAG:
# append result to global list variable
host_callback.call(_DEBUG_ADD_QP, qp_fullstep)
callback(_DEBUG_ADD_QP, qp_fullstep)
return qp_fullstep
......@@ -357,7 +357,7 @@ def generate_nuts_tree(
global _DEBUG_FLAG
if _DEBUG_FLAG:
host_callback.call(_DEBUG_FINISH_TREE, None)
callback(_DEBUG_FINISH_TREE, None)
return current_tree
......@@ -544,7 +544,7 @@ def iterative_build_tree(
global _DEBUG_FLAG
if _DEBUG_FLAG:
host_callback.call(_DEBUG_FINISH_SUBTREE, None)
callback(_DEBUG_FINISH_SUBTREE, None)
# The depth of a tree which was aborted early is possibly ill defined
depth = jnp.where(n == max_num_proposals, depth, -1)
......
......@@ -274,7 +274,7 @@ def _trust_ncg(
subproblem_kwargs: Optional[Dict[str, Any]] = None,
name: Optional[str] = None
) -> OptimizeResults:
from jax.experimental.host_callback import call
from jax.debug import callback
maxiter = 200 if maxiter is None else maxiter
......@@ -421,7 +421,7 @@ def _trust_ncg(
"nhev": params.nhev,
"hit": sub_result.hits_boundary
}
call(pp, printable_state, result_shape=None)
callback(pp, printable_state, result_shape=None)
return params
def _trust_region_cond_f(params: _TrustRegionState) -> bool:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment