Always use JAX arrays when calling JAX's norm implementation. This resolves an issue with the most recent version of JAX.