remove depricated `jax.tree_map`

Merge request reports

Loading