[explicit sharding] sharding rule for scatter is not implemented · Issue #28111 · jax-ml/jax · GitHub

[explicit sharding] sharding rule for scatter is not implemented #28111


Closed
PhilipVinc opened this issue Apr 18, 2025 · 3 comments · Fixed by #28123
Assignees: yashk2810
Labels: bug (Something isn't working)

Comments

PhilipVinc commented Apr 18, 2025

Description

I'm trying to convert the following snippet to explicit sharding mode:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, AxisType, set_mesh
from jax.experimental.shard import reshard

jax.config.update('jax_num_cpu_devices', 2)

# A 1D mesh of two CPU devices whose axis "i" uses explicit sharding.
mesh = jax.make_mesh((2,), ("i",), axis_types=(AxisType.Explicit,))
set_mesh(mesh)

x = np.random.uniform(size=(jax.device_count() * 2, 3))
i = np.random.randint(0, x.shape[1], len(x))
j = np.random.randint(0, x.shape[1], len(x))
# Shard all three arrays along their leading axis.
x = reshard(x, P("i"))
i = reshard(i, P("i"))
j = reshard(j, P("i"))

@jax.jit
def f1(x, i, j):
    x_a_j = x.at[:, j].get(out_sharding=jax.typeof(i).sharding)
    print("x_a_j", jax.typeof(x_a_j), jax.typeof(x_a_j).sharding)
    return x.at[:, i].set(x_a_j)

f1(x,i,j)

The snippet fails with the following error:

In [2]: f1(x,i,j)
x_a_j ShapedArray(float32[4@i,4]) NamedSharding(mesh=AbstractMesh('i': 2, axis_types=(Explicit,)), spec=PartitionSpec('i', None))
---------------------------------------------------------------------------
ShardingTypeError                         Traceback (most recent call last)
Cell In[2], line 1
----> 1 f1(x,i,j)

    [... skipping hidden 14 frame]

Cell In[1], line 33, in f1(x, i, j)
     31 x_a_j = x.at[:, j].get(out_sharding=jax.typeof(i).sharding)
     32 print("x_a_j", jax.typeof(x_a_j), jax.typeof(x_a_j).sharding)
---> 33 return x.at[:, i].set(x_a_j)

ShardingTypeError: sharding rule for scatter is not implemented. Please file a bug at https://github.com/jax-ml/jax/issues. You can work around this error by dropping that operation into full auto sharding mode via: `jax.experimental.shard.auto_axes(fun, out_shardings=...)`
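
As the error message says, the unimplemented scatter rule can be worked around by dropping just the `.set` into full auto sharding mode. A minimal sketch of that workaround, following the `out_shardings=` keyword exactly as the error message prints it (the wrapper name `scatter_cols` is hypothetical, and this is untested):

from jax.experimental.shard import auto_axes

@jax.jit
def f1_workaround(x, i, j):
    x_a_j = x.at[:, j].get(out_sharding=jax.typeof(i).sharding)
    # Only the scatter runs under auto sharding; the result is constrained
    # back to x's explicit sharding on the way out.
    scatter_cols = auto_axes(
        lambda x, idx, upd: x.at[:, idx].set(upd),
        out_shardings=jax.typeof(x).sharding,
    )
    return scatter_cols(x, i, x_a_j)

f1_workaround(x, i, j)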

Not strictly part of this bug, but related: I'm also trying to implement the same snippet (with explicit sharding) under vmap. My current attempt is below, but it does not work, and it's unclear to me how to make it work:

@jax.jit
@jax.vmap
def f2(x, i, j):
    x_j = x.at[j].get(out_sharding=jax.typeof(x).sharding)
    return x.at[i].set(x_j)

f2(x,i,j)

System info (python version, jaxlib version, accelerator, etc.)

In [3]: import jax; jax.print_environment_info()
jax:    0.6.0
jaxlib: 0.6.0
numpy:  2.2.4
python: 3.13.1 (main, Dec 19 2024, 14:22:59) [Clang 18.1.8 ]
device info: cpu-2, 2 local devices
process_count: 1
platform: uname_result(system='Darwin', node='mbp-10841385', release='24.4.0', version='Darwin Kernel Version 24.4.0: Wed Mar 19 21:16:34 PDT 2025; root:xnu-11417.101.15~1/RELEASE_ARM64_T6000', machine='arm64')
PhilipVinc added the bug label Apr 18, 2025
yashk2810 self-assigned this Apr 18, 2025
yashk2810 commented Apr 18, 2025

I have a fix. I'll push it in some time :)

copybara-service bot pushed a commit that referenced this issue Apr 18, 2025

…n `out_sharding` argument in `set`, use the input array's `sharding`. Since this is an update, after `.set`, the input array's sharding should be preserved.

Fixes #28111

PiperOrigin-RevId: 749059171
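
Per the commit message, the effect of the fix is that `.set` with no `out_sharding` argument now falls back to the input array's sharding, so the original snippet goes through unchanged. A sketch of the expected post-fix behavior (untested, assuming a build that contains the fix):

@jax.jit
def f1_fixed(x, i, j):
    x_a_j = x.at[:, j].get(out_sharding=jax.typeof(i).sharding)
    # With the fix, .set defaults to x's sharding, so the scatter
    # no longer needs a dedicated sharding rule or out_sharding.
    return x.at[:, i].set(x_a_j)

out = f1_fixed(x, i, j)
print(jax.typeof(out).sharding)  # expected: same PartitionSpec('i', None) as x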
PhilipVinc (Author) commented

Thanks a lot for the prompt fix!

Does the fix also work under vmap, as in my second snippet?

yashk2810 (Collaborator) commented

Yes!
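
That is, on a build with the fix the vmapped snippet from the issue should also run as written, since the inner `.set` again takes its sharding from the (batched) input. A sketch, untested here:

@jax.jit
@jax.vmap
def f2(x, i, j):
    x_j = x.at[j].get(out_sharding=jax.typeof(x).sharding)
    # The scatter inherits x's sharding under vmap as well.
    return x.at[i].set(x_j)

print(jax.typeof(f2(x, i, j)).sharding)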

charleshofer pushed a commit to ROCm/jax that referenced this issue Apr 30, 2025

…n `out_sharding` argument in `set`, use the input array's `sharding`. Since this is an update, after `.set`, the input array's sharding should be preserved.

Fixes jax-ml#28111

PiperOrigin-RevId: 749089846