Skip to content
Snippets Groups Projects
Commit 357cd238 authored by Matteo Guardiani's avatar Matteo Guardiani
Browse files

408-host_callback_deprecated: address review.

parent 4e99866a
No related branches found
No related tags found
1 merge request!937Resolve "host_callback_depricated"
Pipeline #207168 passed
...@@ -6,7 +6,7 @@ from typing import Callable, NamedTuple, TypeVar, Union ...@@ -6,7 +6,7 @@ from typing import Callable, NamedTuple, TypeVar, Union
from jax import lax from jax import lax
from jax import numpy as jnp from jax import numpy as jnp
from jax import random, tree_util from jax import random, tree_util
from jax.experimental import host_callback from jax.debug import callback
from jax.scipy.special import expit from jax.scipy.special import expit
from .lax import cond, fori_loop, while_loop from .lax import cond, fori_loop, while_loop
...@@ -124,7 +124,7 @@ def leapfrog_step( ...@@ -124,7 +124,7 @@ def leapfrog_step(
global _DEBUG_FLAG global _DEBUG_FLAG
if _DEBUG_FLAG: if _DEBUG_FLAG:
# append result to global list variable # append result to global list variable
host_callback.call(_DEBUG_ADD_QP, qp_fullstep) callback(_DEBUG_ADD_QP, qp_fullstep)
return qp_fullstep return qp_fullstep
...@@ -357,7 +357,7 @@ def generate_nuts_tree( ...@@ -357,7 +357,7 @@ def generate_nuts_tree(
global _DEBUG_FLAG global _DEBUG_FLAG
if _DEBUG_FLAG: if _DEBUG_FLAG:
host_callback.call(_DEBUG_FINISH_TREE, None) callback(_DEBUG_FINISH_TREE, None)
return current_tree return current_tree
...@@ -544,7 +544,7 @@ def iterative_build_tree( ...@@ -544,7 +544,7 @@ def iterative_build_tree(
global _DEBUG_FLAG global _DEBUG_FLAG
if _DEBUG_FLAG: if _DEBUG_FLAG:
host_callback.call(_DEBUG_FINISH_SUBTREE, None) callback(_DEBUG_FINISH_SUBTREE, None)
# The depth of a tree which was aborted early is possibly ill defined # The depth of a tree which was aborted early is possibly ill defined
depth = jnp.where(n == max_num_proposals, depth, -1) depth = jnp.where(n == max_num_proposals, depth, -1)
......
...@@ -274,7 +274,7 @@ def _trust_ncg( ...@@ -274,7 +274,7 @@ def _trust_ncg(
subproblem_kwargs: Optional[Dict[str, Any]] = None, subproblem_kwargs: Optional[Dict[str, Any]] = None,
name: Optional[str] = None name: Optional[str] = None
) -> OptimizeResults: ) -> OptimizeResults:
from jax.experimental.host_callback import call from jax.debug import callback
maxiter = 200 if maxiter is None else maxiter maxiter = 200 if maxiter is None else maxiter
...@@ -421,7 +421,7 @@ def _trust_ncg( ...@@ -421,7 +421,7 @@ def _trust_ncg(
"nhev": params.nhev, "nhev": params.nhev,
"hit": sub_result.hits_boundary "hit": sub_result.hits_boundary
} }
call(pp, printable_state, result_shape=None) callback(pp, printable_state, result_shape=None)
return params return params
def _trust_region_cond_f(params: _TrustRegionState) -> bool: def _trust_region_cond_f(params: _TrustRegionState) -> bool:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment