Nested jax.jit hangs on second call when jax.Array stored in static pytree metadata
#28311
Labels: bug
Description
When combining nested JIT calls with a constant scalar jax.Array passed as non-pytree (static) metadata, the second call to the jax.jit-compiled function hangs indefinitely:
The printed output here is:
and then the Python interpreter hangs forever.
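For contrast, here is a minimal sketch of the same nested-jit shape, with the constant scalar stored as a pytree leaf instead of in the static metadata. It is not part of the original report and does not reproduce the hang. The class Scaled and the functions inner and outer are hypothetical names for illustration. Because tree_flatten returns None as aux_data, no array value ends up in the treedef, which (as far as we understand) jax.jit uses as part of its cache key.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Scaled:
    """Hypothetical container: `x` is data, `scale` is a constant scalar."""

    def __init__(self, x, scale):
        self.x = x
        self.scale = scale

    def tree_flatten(self):
        # Both values are children (leaves); aux_data is None, so the
        # static part of the treedef holds no jax.Array.
        return (self.x, self.scale), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def inner(s):
    return s.x * s.scale


@jax.jit
def outer(s):
    # Nested jit: outer traces through inner.
    return inner(s) + 1.0


s = Scaled(jnp.arange(3.0), jnp.asarray(2.0))
print(outer(s))  # first call: traces and compiles; prints [1. 3. 5.]
print(outer(s))  # second call: reuses the cached executable
```

The trade-off is that scale is now traced. Changing its value does not trigger recompilation, but it also cannot be used where a concrete Python value is required, for example as an array shape.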
I would have expected either:
System info (python version, jaxlib version, accelerator, etc.)