lax collective axis names not recognized with JIT sharding #28666

Closed
jaketae opened this issue May 10, 2025 · 1 comment
Labels
bug Something isn't working

Comments

jaketae commented May 10, 2025

Description

TL;DR: Is there a way to use lax primitives with jit's sharding API?

I'm trying to migrate from pmap to jit using sharding arguments. However, lax collectives like psum and pmean fail because the axis_name passed to those primitives is not recognized. I'm using flax.nnx, but I observe the same behavior with this JAX-only reproducer:

import jax, jax.numpy as jnp, numpy as np
from jax import lax
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from functools import partial

mesh = Mesh(np.asarray(jax.devices()), ('data',))
in_sharding = NamedSharding(mesh, P('data',))
out_sharding = NamedSharding(mesh, P())

@partial(           
    jax.jit,
    in_shardings=in_sharding,   
    out_shardings=out_sharding,       
)
def psum_jit(x):
    return lax.psum(x, 'data') 

# with mesh: (doesn't help)
y = psum_jit(jnp.arange(8, dtype=jnp.float32))
print(y)                
jax.debug.visualize_array_sharding(y)

Output:

  File "/Users/jaketae/Developer/github/jax-playground/src/jit_test.py", line 19, in allreduce_mean
    return lax.psum(x, 'data') 
NameError: unbound axis name: data

I'm not too familiar with JAX internals, but the mesh context manager seems to make axis names available via _ThreadResourcesLocalState, whereas lax collectives look them up in the AxisEnv inside the TracingContext. I haven't quite understood if/how the two are related.

Interestingly, using shard_map with jit resolves the issue:

import jax, jax.numpy as jnp, numpy as np
from jax import lax, jit
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.asarray(jax.devices()), ('data',))

def psum(x):
    return lax.psum(x, axis_name='data')

psum_shard_map = shard_map(
    psum,
    mesh=mesh,
    in_specs=P('data',),  
    out_specs=P()   
)
psum_jit = jit(psum_shard_map) 

y = psum_jit(jnp.arange(8, dtype=jnp.float32))
print(y)        
jax.debug.visualize_array_sharding(y)

Other failed attempts (just jit, no shard_map), sketched after this list:

  • Use the with mesh context;
  • Use the with jax.sharding.use_mesh(mesh) context;
  • Use pjit (aware that this is largely deprecated and subsumed by jit);
  • Run on another platform with a different JAX version (indicated below).
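
For reference, a minimal sketch of the first two attempts (reusing the mesh and psum_jit defined above); neither binds the axis name for lax.psum, and the call fails the same way:

# Attempt 1: wrap the call in the mesh context manager.
with mesh:
    psum_jit(jnp.arange(8, dtype=jnp.float32))  # NameError: unbound axis name: data

# Attempt 2: use the explicit use_mesh context manager instead.
with jax.sharding.use_mesh(mesh):
    psum_jit(jnp.arange(8, dtype=jnp.float32))  # same error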

Is this behavior expected? Thanks in advance.

System info (python version, jaxlib version, accelerator, etc.)

Tested on two different systems (an 8xH100 Linux node and an Apple M2 laptop) with slightly different package versions; both yield the same error.

>>> import jax; jax.print_environment_info()
jax:    0.4.33
jaxlib: 0.4.33
numpy:  1.26.4
python: 3.12.4 (main, Jul 29 2024, 21:12:39) [GCC 11.4.1 20230605 (Red Hat 11.4.1-2)]
jax.devices (8 total, 8 local): [CudaDevice(id=0) CudaDevice(id=1) ... CudaDevice(id=6) CudaDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='lg1n0016.optiver.us', release='5.14.0-362.8.1.el9_3.x86_64', version='#1 SMP PREEMPT_DYNAMIC Tue Oct 3 11:12:36 EDT 2023', machine='x86_64')

$ nvidia-smi
Fri May  9 23:48:07 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:1C:00.0 Off |                    0 |
| N/A   25C    P0            112W /  700W |    4698MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00000000:2A:00.0 Off |                    0 |
| N/A   27C    P0            115W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  |   00000000:48:00.0 Off |                    0 |
| N/A   26C    P0            114W /  700W |     542MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  |   00000000:51:00.0 Off |                    0 |
| N/A   26C    P0            112W /  700W |     560MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  |   00000001:1C:00.0 Off |                    0 |
| N/A   24C    P0            112W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  |   00000001:25:00.0 Off |                    0 |
| N/A   25C    P0            113W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  |   00000001:48:00.0 Off |                    0 |
| N/A   26C    P0            111W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  |   00000001:51:00.0 Off |                    0 |
| N/A   26C    P0            114W /  700W |     660MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
>>> import jax; jax.print_environment_info()
jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.2.3
python: 3.10.12 (main, Aug  8 2023, 19:18:21) [Clang 14.0.3 (clang-1403.0.22.14.1)]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='Jakes-MBP-6.attlocal.net', release='23.2.0', version='Darwin Kernel Version 23.2.0: Wed Nov 15 21:55:06 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6020', machine='arm64')
@jaketae added the bug (Something isn't working) label on May 10, 2025
@jaketae changed the title from "lax collectives with sharding" to "lax collectives with JIT sharding" on May 10, 2025
@jaketae changed the title from "lax collectives with JIT sharding" to "lax collective axis names not recognized with JIT sharding" on May 10, 2025
@yashk2810 (Collaborator) commented:

Yes, this is working as expected.

jax.lax collectives only work under shard_map. See https://docs.jax.dev/en/latest/notebooks/shard_map.html for more information!

I am closing the issue but feel free to reopen if you have more questions :)
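
For comparison, a minimal sketch of what the same reduction looks like under plain jit, assuming the goal is the global sum of x (global_sum is just an illustrative name): jit's automatic partitioning works over the logical global array, so no axis_name is needed and a plain jnp.sum expresses the reduction, with the compiler inserting any cross-device communication implied by the shardings.

import jax, jax.numpy as jnp, numpy as np
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

mesh = Mesh(np.asarray(jax.devices()), ('data',))

@partial(
    jax.jit,
    in_shardings=NamedSharding(mesh, P('data')),
    out_shardings=NamedSharding(mesh, P()),
)
def global_sum(x):
    # x is the global array here; the SPMD partitioner adds the all-reduce
    # needed to produce the replicated result requested by out_shardings.
    return jnp.sum(x)

y = global_sum(jnp.arange(8, dtype=jnp.float32))  # 28.0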
