Non-jitted model evaluation in optimize_kl
I have a use case where I want to evaluate part of my model on a GPU, while keeping the initial and final parts of the model on the CPU. To my understanding, such a function cannot be jitted (at the moment). However, the conjugate gradient implementation jits the model evaluation both explicitly and implicitly (through the call to jax.lax.while_loop). The model is then executed fully on the CPU. Trying to avoid this compilation by removing the jit call in conjugate_gradient leads to an error in the while_loop that follows, due to some of the arrays being on different devices.
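To make the setup concrete, here is a minimal sketch of the kind of model I mean. It is a toy stand-in, not my actual model: the function `model`, its three stages, and the choice of `jax.devices()[0]` as "the accelerator" are assumptions for illustration only. On a CPU-only machine both device handles point to the same device.

```python
import jax
import jax.numpy as jnp

# Assumption: the accelerator is whatever JAX selects as its default backend.
# On a CPU-only machine this is the CPU again.
cpu = jax.devices("cpu")[0]
accel = jax.devices()[0]


def model(x):
    # Hypothetical three-stage model with explicit device transfers.
    h = jnp.sin(jax.device_put(x, cpu))      # initial part on the CPU
    h = jnp.tanh(jax.device_put(h, accel))   # middle part on the accelerator
    return jax.device_put(h, cpu) ** 2       # final part back on the CPU


x = jnp.linspace(0.0, 1.0, 8)
y = model(x)  # evaluated op by op, without jax.jit around the full model
print(y.devices())
```

Evaluated like this, each stage runs where its input was placed, which is the behavior I would like to keep inside the optimization routine.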
As a sanity check, I have tested jax.linearize and jax.linear_transpose on the model, which works without any (apparent) issues. So I am hopeful that optimize_kl could work with such a model if all jit compilations of the full model call were avoided.
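A check of this kind could look roughly like the following sketch. Again, `model` is a hypothetical toy function rather than my real model, and the dot-product comparison at the end is just one possible way to confirm that the transposed map is consistent with the linearized one. On a CPU-only machine the two device handles coincide, so there the snippet only exercises the API calls and not an actual cross-device transfer.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]
accel = jax.devices()[0]  # assumption: default backend acts as the accelerator


def model(x):
    # Hypothetical mixed-device toy model.
    h = jnp.sin(jax.device_put(x, cpu))
    h = jnp.tanh(jax.device_put(h, accel))
    return jax.device_put(h, cpu) ** 2


x = jax.device_put(jnp.linspace(0.1, 1.0, 8), cpu)

# Linearize the model around x: y = model(x), jvp applies the Jacobian J.
y, jvp = jax.linearize(model, x)

# Transpose the linear map; x is only used for its shape and dtype.
vjp = jax.linear_transpose(jvp, x)

v = jnp.ones_like(x)
w = jnp.ones_like(y)
(jt_w,) = vjp(w)  # J^T w

# Adjointness check: <w, J v> should equal <J^T w, v>.
lhs = jnp.vdot(w, jvp(v))
rhs = jnp.vdot(jt_w, v)
print(lhs, rhs)
```

None of these calls wraps the full model in jax.jit, which is why I would expect the same to be possible inside the optimizer.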
If I understand this issue correctly, JAX is working on an API that would allow this kind of behavior in a jitted function; however, I couldn't find any current information on that feature. So while this might resolve the issue, it is unclear when it would be available.
As there are already flags like kl_jit and residual_jit, would it be possible to add a flag that prevents any jit compilation within the optimization routine? Are there any other places where an implicit jit compilation takes place that I am missing?
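For reference, JAX itself has a global switch, the jax.disable_jit context manager, under which jax.lax.while_loop is executed as an ordinary Python loop instead of being traced and compiled. The sketch below only demonstrates that behavior on a toy loop; the functions `cond` and `body` and the counter are made up for illustration, and it makes no claim about whether wrapping optimize_kl this way is sufficient or fast enough.

```python
import jax
import jax.numpy as jnp

# Counts how often the Python body function is actually executed.
counter = {"body_calls": 0}


def cond(state):
    i, _ = state
    return i < 5


def body(state):
    counter["body_calls"] += 1
    i, x = state
    return i + 1, 0.5 * x


with jax.disable_jit():
    # With jit disabled, the loop runs in Python, so the body is called
    # once per iteration rather than being traced a single time.
    i_final, x_final = jax.lax.while_loop(cond, body, (0, jnp.float32(1.0)))

print(counter["body_calls"], int(i_final), float(x_final))
```

A flag inside the optimization routine would be preferable to a global switch like this, since it could leave the jit compilation of unrelated parts untouched.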
If it would help, I could provide a minimal working example for this issue.