sample mapping options
Implementation of sample mapping options for MetricKL() and for mean_value_and_grad(). Pre-implemented options selectable with string keys include jax.lax.map and jax.pmap, the latter for parallel mapping. Other mapping options beyond the pre-implemented ones can be directly passed as the corresponding function itself. Details on how to do this are provided in the documentation.