Nested `jax.jit` hangs on second call when `jax.Array` stored in static pytree metadata · Issue #28311 · jax-ml/jax · GitHub
Open
@danieldjohnson

Description

Combining nested `jax.jit` calls with a constant scalar `jax.Array` passed as static (non-pytree) metadata makes the second call hang indefinitely:

```python
import dataclasses
import jax
import jax.numpy as jnp

@dataclasses.dataclass
class Indexer:
    index: int

jax.tree_util.register_dataclass(Indexer, data_fields=[], meta_fields=["index"])

@jax.jit
def slice_it(x: jax.Array, indexer: Indexer):
    print("Tracing slice_it")
    jax.debug.print("Running slice_it")
    return x[:indexer.index]

def outer_fn():
    print("Running outer_fn")
    indexer = Indexer(jnp.array(3))  # <- oops, passed a JAX array instead of an int

    @jax.jit
    def inner_fn(x: jax.Array):
        print("Tracing inner_fn")
        jax.debug.print("Running inner_fn")
        return slice_it(x, indexer)

    return inner_fn


test_fn = outer_fn()
print("First call result:", test_fn(jnp.arange(30)))
test_fn = outer_fn()
print("Second call result:", test_fn(jnp.arange(30)))
```

The printed output here is:

```
Running outer_fn
Tracing inner_fn
Tracing slice_it
Running inner_fn
Running slice_it
First call result: [0 1 2]
Running outer_fn
Tracing inner_fn
```

and then the Python interpreter hangs forever.
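For context, here is a minimal sketch of what registering `index` as a meta field means, reusing the same `register_dataclass` call as the repro. The value is not a pytree leaf, so it is never traced. Instead it is stored in the treedef, and `jax.jit` keys its cache partly on the input treedef. This only illustrates the mechanism. It does not establish why the second trace hangs.

```python
import dataclasses

import jax


@dataclasses.dataclass
class Indexer:
    index: int


# Same registration as in the repro: no data fields, one meta field.
jax.tree_util.register_dataclass(Indexer, data_fields=[], meta_fields=["index"])

leaves, treedef = jax.tree_util.tree_flatten(Indexer(3))
print(leaves)  # [] : `index` is not a leaf, so it is never traced

# The meta value is part of the treedef: equal values give equal structures,
# different values give different structures.
same = jax.tree_util.tree_structure(Indexer(3)) == jax.tree_util.tree_structure(Indexer(3))
different = jax.tree_util.tree_structure(Indexer(3)) == jax.tree_util.tree_structure(Indexer(4))
print(same, different)  # True False
```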

I would have expected either:

  • an informative error message telling me what I did wrong
  • this to work correctly and return the same result on both calls
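Here is a possible user-side workaround. It assumes the intended value is a plain Python `int`, which is what the `# <- oops` comment in the repro suggests. The `__post_init__` check is a hypothetical guard added for illustration, not something JAX provides. It turns the mistake into an immediate `TypeError` rather than a hang. It sidesteps the bug but does not fix it.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Indexer:
    index: int

    def __post_init__(self):
        # Hypothetical guard: static metadata should be a hashable Python int,
        # not a jax.Array (bool is excluded because it subclasses int).
        if isinstance(self.index, bool) or not isinstance(self.index, int):
            raise TypeError(
                f"Indexer.index must be a Python int, got {type(self.index).__name__}"
            )


jax.tree_util.register_dataclass(Indexer, data_fields=[], meta_fields=["index"])


@jax.jit
def slice_it(x: jax.Array, indexer: Indexer) -> jax.Array:
    # indexer.index is a concrete int at trace time, so the slice size is static.
    return x[: indexer.index]


def outer_fn():
    indexer = Indexer(3)  # plain Python int instead of jnp.array(3)

    @jax.jit
    def inner_fn(x: jax.Array) -> jax.Array:
        return slice_it(x, indexer)

    return inner_fn


first = outer_fn()(jnp.arange(30))
second = outer_fn()(jnp.arange(30))
print("First call result:", first)    # First call result: [0 1 2]
print("Second call result:", second)  # Second call result: [0 1 2]

# The original mistake is now rejected when the Indexer is constructed.
try:
    Indexer(jnp.array(3))
except TypeError as err:
    print("Rejected:", err)
```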

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.4.37
jaxlib: 0.4.36
numpy:  2.2.0
python: 3.12.5 (main, Aug 14 2024, 04:32:18) [Clang 18.1.8 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='Daniels-Laptop.local', release='23.4.0', version='Darwin Kernel Version 23.4.0: Wed Feb 21 21:44:54 PST 2024; root:xnu-10063.101.15~2/RELEASE_ARM64_T6031', machine='arm64')
```

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
