forest_util.py norm: Remove usage of list
Always use JAX arrays when calling JAX's norm implementation. This resolves an issue with the most recent version of JAX.
Always use JAX arrays when calling JAX's norm implementation. This resolves an issue with the most recent version of JAX.