lax collective axis names not recognized with JIT sharding #28666
Labels: bug
Description
TL;DR: Is there a way to use lax primitives with jit's sharding API?

I'm trying to migrate from pmap to jit using sharding arguments. However, lax collectives like psum and pmean seem to fail because they do not recognize the axis_name used in those primitives. I'm using flax.nnx, but I observe the same with this JAX-only reproducer:

Output:
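For readers without the original reproducer, here is a minimal sketch of the same failing pattern. It is an illustration, not the author's code, and it assumes a one-dimensional mesh over all available devices with a hypothetical axis name "data". The input is sharded with NamedSharding and lax.pmean is called directly inside jax.jit. Tracing is expected to raise a NameError, because plain jit never binds "data" as a collective axis; the exact message wording differs between JAX versions.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1-D mesh over every available device, with the hypothetical axis name "data".
n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))

@jax.jit
def global_mean(x):
    # Under plain jit, "data" only describes how x is laid out across devices.
    # It is not bound as a named axis for collectives, so pmean cannot resolve it.
    return lax.pmean(jnp.mean(x), axis_name="data")

# Shard a 1-D array along the "data" mesh axis (length divisible by n).
x = jax.device_put(jnp.arange(8.0 * n), sharding)

try:
    global_mean(x)
    failed = False
except NameError as err:  # typically reports an unbound axis name; wording varies by version
    failed = True
    print(err)
```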
I'm not too familiar with JAX internals, but the mesh context manager seems to make axes available via _ThreadResourcesLocalState, whereas lax collectives work with AxisEnv within TracingContext. I haven't quite understood if/how they are related.

Interestingly, using shard_map with jit resolves the issue:

Other failed attempts (just jit, no shard_map): a with mesh: context; a with jax.sharding.use_mesh(mesh): context; pjit (aware that this is largely deprecated and subsumed by jit).

Is this behavior expected? Thanks in advance.
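A minimal sketch of the shard_map + jit pattern that works, under the same assumptions as before (a one-dimensional mesh with the hypothetical axis name "data"). The import fallback assumes shard_map is exported as jax.shard_map in newer releases and lives in jax.experimental.shard_map in older ones. Each device averages its local block, and lax.pmean then averages those partial means across the "data" axis, which gives the global mean because all blocks have equal length.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # top-level export in newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

# Assumption: a 1-D mesh over every available device, axis name "data" is hypothetical.
n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

def local_mean(x_block):
    # Inside shard_map each device sees only its own block, and "data" is bound
    # as a named axis, so lax.pmean can average the per-device means.
    return lax.pmean(jnp.mean(x_block, keepdims=True), axis_name="data")

global_mean = jax.jit(
    shard_map(
        local_mean,
        mesh=mesh,
        in_specs=(P("data"),),  # split the single argument along "data"
        out_specs=P(),          # pmean output is identical on every device
    )
)

x = jnp.arange(8.0 * n)  # length divisible by the number of devices
result = global_mean(x)
print(result)
```

The difference from the failing sketch is that shard_map switches to a per-device view and binds "data" for the body of local_mean, which is the axis environment lax collectives consult; plain jit with sharded inputs partitions arrays automatically but keeps a global view and binds no axis names.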
System info (python version, jaxlib version, accelerator, etc.)
Tested on two different systems: 8xH100, Apple M2, with slightly different package versions. They both yield the same error.