Description
When nested `jax.jit` calls are combined with a constant scalar `jax.Array` passed as static pytree metadata (a `meta_fields` entry rather than a data leaf), the second call causes `jax.jit` to hang indefinitely:
```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Indexer:
    index: int


jax.tree_util.register_dataclass(Indexer, data_fields=[], meta_fields=["index"])


@jax.jit
def slice_it(x: jax.Array, indexer: Indexer):
    print("Tracing slice_it")
    jax.debug.print("Running slice_it")
    return x[:indexer.index]


def outer_fn():
    print("Running outer_fn")
    indexer = Indexer(jnp.array(3))  # <- oops, passed a JAX array instead of an int

    @jax.jit
    def inner_fn(x: jax.Array):
        print("Tracing inner_fn")
        jax.debug.print("Running inner_fn")
        return slice_it(x, indexer)

    return inner_fn


test_fn = outer_fn()
print("First call result:", test_fn(jnp.arange(30)))
test_fn = outer_fn()
print("Second call result:", test_fn(jnp.arange(30)))
```
The printed output here is:
```
Running outer_fn
Tracing inner_fn
Tracing slice_it
Running inner_fn
Running slice_it
First call result: [0 1 2]
Running outer_fn
Tracing inner_fn
```
and then the Python interpreter hangs forever.
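For comparison, here is a minimal sketch of the variant that the `# <- oops` comment points to: the same nested-`jit` structure, but with a plain Python `int` stored in the `index` meta field. This assumes the hang is tied to using a `jax.Array` as metadata, which the repro suggests but does not prove. With an `int`, each `outer_fn()` call should trace `inner_fn` afresh and both calls should return `[0 1 2]`. If the value starts out as a concrete JAX scalar, `int(...)` outside of `jit` converts it.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Indexer:
    index: int  # static metadata: keep this a plain (hashable) Python int


jax.tree_util.register_dataclass(Indexer, data_fields=[], meta_fields=["index"])


@jax.jit
def slice_it(x: jax.Array, indexer: Indexer) -> jax.Array:
    # indexer.index is a concrete Python int while tracing, so the slice is static.
    return x[: indexer.index]


def outer_fn():
    indexer = Indexer(3)  # Python int instead of jnp.array(3)

    @jax.jit
    def inner_fn(x: jax.Array) -> jax.Array:
        return slice_it(x, indexer)

    return inner_fn


first = outer_fn()(jnp.arange(30))
second = outer_fn()(jnp.arange(30))
print("First call result:", first)
print("Second call result:", second)
```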
I would have expected either:
- an informative error message telling me what I did wrong
- this to work correctly and return the same result on both calls
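Until JAX reports this itself, user code can get the first behaviour by validating the field eagerly. The sketch below adds a hypothetical `__post_init__` guard to `Indexer` (this is not something `register_dataclass` does for you) so that `Indexer(jnp.array(3))` fails immediately with a `TypeError` instead of reaching `jax.jit`. The check is deliberately strict and also rejects NumPy integer scalars; loosen it if those should be allowed.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Indexer:
    index: int

    def __post_init__(self):
        # Hypothetical guard: meta fields become part of the pytree structure,
        # so only accept a plain Python int here (a jax.Array is not an int).
        if not isinstance(self.index, int):
            raise TypeError(
                "Indexer.index is static pytree metadata and must be a Python int, "
                f"got {type(self.index).__name__}; convert concrete values with int(...)."
            )


jax.tree_util.register_dataclass(Indexer, data_fields=[], meta_fields=["index"])

ok = Indexer(3)  # passes validation
try:
    Indexer(jnp.array(3))  # rejected before it can reach jax.jit
except TypeError as err:
    print("Rejected:", err)
```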
System info (python version, jaxlib version, accelerator, etc.)
```
jax: 0.4.37
jaxlib: 0.4.36
numpy: 2.2.0
python: 3.12.5 (main, Aug 14 2024, 04:32:18) [Clang 18.1.8 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='Daniels-Laptop.local', release='23.4.0', version='Darwin Kernel Version 23.4.0: Wed Feb 21 21:44:54 PST 2024; root:xnu-10063.101.15~2/RELEASE_ARM64_T6031', machine='arm64')
```