JAX: Actually return the derivative, not a tuple containing the derivative

15 jobs for jax_fixes in 20 minutes and 59 seconds (queued for 1 second)
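The commit message does not say which JAX transform produced the tuple. The sketch below assumes the derivative came from `jax.vjp`, whose pullback returns one cotangent per primal input (a 1-tuple when there is a single input). It shows how to unpack that tuple so the caller gets the derivative itself. The function `f` and the helper `derivative` are hypothetical and exist only for illustration. The last lines show the same tuple-vs-array distinction with `jax.grad` and `argnums`.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical scalar function used only for illustration.
    return jnp.sin(x) * x


def derivative(x):
    # jax.vjp returns the primal output and a pullback (VJP) function.
    y, vjp_fn = jax.vjp(f, x)
    # The pullback returns a tuple with one cotangent per primal argument,
    # so for the single input x it is a 1-tuple: (df/dx,).
    (dx,) = vjp_fn(jnp.ones_like(y))
    # Return the derivative itself, not the tuple that wraps it.
    return dx


x = jnp.array(2.0)
d = derivative(x)

# The same distinction with jax.grad: a tuple argnums yields a tuple of
# gradients, while the default (argnums=0) yields the gradient array itself.
grad_tuple = jax.grad(f, argnums=(0,))(x)
grad_plain = jax.grad(f)(x)
```

Unpacking with `(dx,) = ...` rather than indexing with `[0]` also fails loudly if the pullback ever returns more than one cotangent.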