Skip to content
Snippets Groups Projects
Commit 12bc47ee authored by Gordian Edenhofer's avatar Gordian Edenhofer
Browse files

smap: Pass unmapped args via partial not carry

parent 77695e7f
No related branches found
No related tags found
1 merge request!842Introduce a more flexible sequential map
Pipeline #161134 passed
......@@ -12,12 +12,12 @@ def _int_or_none(x):
return isinstance(x, int) or x is None
def _fun_reord(unmapped, mapped, *, fun, unflatten, in_axes):
def _fun_reord(_, mapped, *, fun, unmapped, unflatten, in_axes):
un, mapped = list(unmapped), list(mapped)
assert len(un) + len(mapped) == len(in_axes)
args = tuple(un.pop(0) if a is None else mapped.pop(0) for a in in_axes)
y = fun(*unflatten(args))
return unmapped, y
return None, y
# The function over which to `scan` depends on the data. This leads to
......@@ -62,10 +62,11 @@ def _smap(fun, in_axes, out_axes, unroll, *x, **k):
fun_reord = partial(
_fun_reord,
fun=fun,
unmapped=unmapped,
unflatten=partial(tree_unflatten, x_td),
in_axes=in_axes
)
unmapped, y = lax.scan(fun_reord, unmapped, mapped, unroll=unroll)
_, y = lax.scan(fun_reord, None, mapped, unroll=unroll)
if isinstance(out_axes, int):
out_axes = tree_map(lambda _: out_axes, y)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment