[explicit sharding] sharding rule for scatter is not implemented · Issue #28111 · jax-ml/jax · GitHub
[explicit sharding] sharding rule for scatter is not implemented #28111
Closed

@PhilipVinc

Description

I'm trying to convert the following snippet to explicit sharding mode:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, AxisType, set_mesh
from jax.experimental.shard import reshard

jax.config.update('jax_num_cpu_devices', 2)

# 1D mesh with a single explicit axis "i".
mesh = jax.make_mesh((2,), ("i",), axis_types=(AxisType.Explicit,))
set_mesh(mesh)

x = np.random.uniform(size=(jax.device_count() * 2, 3))
i = np.random.randint(0, x.shape[1], len(x))
j = np.random.randint(0, x.shape[1], len(x))
x = reshard(x, P("i"))
i = reshard(i, P("i"))
j = reshard(j, P("i"))

@jax.jit
def f1(x, i, j):
    x_a_j = x.at[:, j].get(out_sharding=jax.typeof(i).sharding)
    print("x_a_j", jax.typeof(x_a_j), jax.typeof(x_a_j).sharding)
    return x.at[:, i].set(x_a_j)

f1(x,i,j)

The snippet fails with the following error:

In [2]: f1(x,i,j)
x_a_j ShapedArray(float32[4@i,4]) NamedSharding(mesh=AbstractMesh('i': 2, axis_types=(Explicit,)), spec=PartitionSpec('i', None))
---------------------------------------------------------------------------
ShardingTypeError                         Traceback (most recent call last)
Cell In[2], line 1
----> 1 f1(x,i,j)

    [... skipping hidden 14 frame]

Cell In[1], line 33, in f1(x, i, j)
     31 x_a_j = x.at[:, j].get(out_sharding=jax.typeof(i).sharding)
     32 print("x_a_j", jax.typeof(x_a_j), jax.typeof(x_a_j).sharding)
---> 33 return x.at[:, i].set(x_a_j)

ShardingTypeError: sharding rule for scatter is not implemented. Please file a bug at https://github.com/jax-ml/jax/issues. You can work around this error by dropping that operation into full auto sharding mode via: `jax.experimental.shard.auto_axes(fun, out_shardings=...)`

Not strictly related to the bug, but also relevant: I'm trying to implement the same snippet (with explicit sharding) under vmap. My current attempt is below, but it does not work, and it's unclear to me how to make it work:

@jax.jit
@jax.vmap
def f2(x, i, j):
    x_j = x.at[j].get(out_sharding=jax.typeof(x).sharding)
    return x.at[i].set(x_j)

f2(x,i,j)

System info (python version, jaxlib version, accelerator, etc.)

In [3]: import jax; jax.print_environment_info()
jax:    0.6.0
jaxlib: 0.6.0
numpy:  2.2.4
python: 3.13.1 (main, Dec 19 2024, 14:22:59) [Clang 18.1.8 ]
device info: cpu-2, 2 local devices
process_count: 1
platform: uname_result(system='Darwin', node='mbp-10841385', release='24.4.0', version='Darwin Kernel Version 24.4.0: Wed Mar 19 21:16:34 PDT 2025; root:xnu-11417.101.15~1/RELEASE_ARM64_T6000', machine='arm64')

Metadata

Labels: bug (Something isn't working)