smap axis ordering
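There seems to be a bug in the axis ordering of smap in the corner case where the input has multiple axes of size 1: for the same in_axes and out_axes, smap returns an array of shape (1, 1, 100), whereas jax.vmap returns (100, 1, 1). Here is a minimal example that compares smap to jax.vmap: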
import nifty8.re as jft
import jax
import jax.numpy as jnp
def simple_function(x):
    return x
a = jnp.arange(100)[..., jnp.newaxis, jnp.newaxis]
print(jft.smap(simple_function, in_axes=(2), out_axes=(1))(a).shape)  # (1, 1, 100)
print(jax.vmap(simple_function, in_axes=(2), out_axes=(1))(a).shape)  # (100, 1, 1)
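For reference, the shape that jax.vmap produces can be checked using only jax, without nifty8. This is a minimal sketch, assuming that smap is meant to follow the same in_axes/out_axes convention as jax.vmap. For an identity function, mapping over input axis 2 and placing the mapped axis at output position 1 amounts to jnp.moveaxis(x, 2, 1). The array b is a hypothetical extra input, not part of the report above, chosen so that no two axes have the same size and the ordering is unambiguous.

```python
import jax
import jax.numpy as jnp


def simple_function(x):
    return x


# Input from the report: shape (100, 1, 1).
a = jnp.arange(100)[..., jnp.newaxis, jnp.newaxis]

# For an identity function, vmap over input axis 2 with the mapped
# axis placed at output position 1 is the same as moving axis 2 to 1.
expected = jnp.moveaxis(a, 2, 1)
result = jax.vmap(simple_function, in_axes=2, out_axes=1)(a)
print(result.shape, expected.shape)

# Hypothetical non-degenerate input (not from the report): all axis
# sizes differ, so a wrong axis order would show up in the shape.
b = jnp.arange(24).reshape(2, 3, 4)
result_b = jax.vmap(simple_function, in_axes=2, out_axes=1)(b)
print(result_b.shape)
```

Running the same two inputs through jft.smap and comparing against these shapes would show whether the mismatch is specific to size 1 axes.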
Edited by Laurin Söding