Nested `jax.jit` hangs on second call when `jax.Array` stored in static pytree metadata · Issue #28311 · jax-ml/jax



Open
danieldjohnson opened this issue Apr 26, 2025 · 2 comments
Labels
bug Something isn't working

Comments

@danieldjohnson
Contributor

Description

When nested `jax.jit` calls are combined with a constant scalar `jax.Array` stored as static (non-data) pytree metadata, the second call to the nested jitted function hangs indefinitely:

```python
import dataclasses
import jax
import jax.numpy as jnp

@dataclasses.dataclass
class Indexer:
    index: int

jax.tree_util.register_dataclass(Indexer, data_fields=[], meta_fields=["index"])

@jax.jit
def slice_it(x: jax.Array, indexer: Indexer):
    print("Tracing slice_it")
    jax.debug.print("Running slice_it")
    return x[:indexer.index]

def outer_fn():
    print("Running outer_fn")
    indexer = Indexer(jnp.array(3))  # <- oops, passed a JAX array instead of an int

    @jax.jit
    def inner_fn(x: jax.Array):
        print("Tracing inner_fn")
        jax.debug.print("Running inner_fn")
        return slice_it(x, indexer)

    return inner_fn


test_fn = outer_fn()
print("First call result:", test_fn(jnp.arange(30)))
test_fn = outer_fn()
print("Second call result:", test_fn(jnp.arange(30)))
```

The printed output here is:

```
Running outer_fn
Tracing inner_fn
Tracing slice_it
Running inner_fn
Running slice_it
First call result: [0 1 2]
Running outer_fn
Tracing inner_fn
```

and then the Python interpreter hangs forever.

I would have expected either:

  • an informative error message telling me what I did wrong
  • this to work correctly and return the same result on both calls
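For comparison, here is a minimal user-side sketch of both expected behaviors. It assumes the fix implied by the `# <- oops` comment (storing a plain Python `int` in the `meta_fields` entry) and adds a hypothetical `__post_init__` guard, which is not something JAX provides, that raises an informative `TypeError` when unhashable metadata such as a `jax.Array` is passed. The debug prints from the repro are dropped.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Indexer:
    index: int

    def __post_init__(self):
        # Hypothetical user-side guard, not JAX behavior: static metadata should
        # be hashable, so fail fast on jax.Array / numpy.ndarray values.
        try:
            hash(self.index)
        except TypeError as err:
            raise TypeError(
                "Indexer.index is static metadata and must be hashable "
                f"(e.g. a Python int), got {type(self.index).__name__}"
            ) from err


jax.tree_util.register_dataclass(Indexer, data_fields=[], meta_fields=["index"])


@jax.jit
def slice_it(x: jax.Array, indexer: Indexer):
    # indexer.index is a static Python int here, so this is a static slice.
    return x[:indexer.index]


def outer_fn():
    indexer = Indexer(3)  # plain Python int instead of jnp.array(3)

    @jax.jit
    def inner_fn(x: jax.Array):
        return slice_it(x, indexer)

    return inner_fn


first = outer_fn()(jnp.arange(30))
second = outer_fn()(jnp.arange(30))
print("First call result:", first)
print("Second call result:", second)

# The guard turns the original mistake into an immediate, descriptive error.
try:
    Indexer(jnp.array(3))
except TypeError as err:
    print("Rejected:", err)
```

The guard trades a little construction-time overhead for failing before any tracing happens, rather than relying on `jax.jit` to notice the problem.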

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.4.37
jaxlib: 0.4.36
numpy:  2.2.0
python: 3.12.5 (main, Aug 14 2024, 04:32:18) [Clang 18.1.8 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='Daniels-Laptop.local', release='23.4.0', version='Darwin Kernel Version 23.4.0: Wed Feb 21 21:44:54 PST 2024; root:xnu-10063.101.15~2/RELEASE_ARM64_T6031', machine='arm64')
```
@mattjj
Collaborator
mattjj commented Apr 26, 2025

Daniel, good to hear from you! Thanks for the clear repro.

This isn't really relevant to the bug but I thought I'd mention: usually we discourage putting jax.Arrays in pytree metadata because, like numpy.ndarrays, they aren't hashable.

A hang is weird though... and I was able to repro the behavior at HEAD (i.e. not just using the now-ancient jax==0.4.37 / jaxlib==0.4.36).
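A minimal sketch of the hashability point, assuming recent `jax` and `numpy` installs: a Python `int` hashes, while a `numpy.ndarray` and a `jax.Array` raise `TypeError`. It also checks that two separately constructed `Indexer(3)` instances produce equal, hashable treedefs; roughly speaking, static metadata becomes part of the treedef that `jax.jit` compares when looking up cached traces. The `is_hashable` helper is hypothetical, for illustration only.

```python
import dataclasses

import jax
import jax.numpy as jnp
import numpy as np


def is_hashable(obj) -> bool:
    """Hypothetical helper: return True if hash(obj) succeeds."""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


@dataclasses.dataclass
class Indexer:
    index: int


jax.tree_util.register_dataclass(Indexer, data_fields=[], meta_fields=["index"])

# A Python int is hashable; numpy and JAX arrays are not.
print(is_hashable(3))
print(is_hashable(np.array(3)))
print(is_hashable(jnp.array(3)))

# With int metadata, independently built instances give equal, hashable treedefs.
t1 = jax.tree_util.tree_structure(Indexer(3))
t2 = jax.tree_util.tree_structure(Indexer(3))
print(t1 == t2, hash(t1) == hash(t2))
```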

@jakevdp
Collaborator
jakevdp commented May 9, 2025

Cross-referencing a related issue: #24204 (no hang, but this is also an issue with including non-hashable objects in pytree metadata)
