Description
I'm trying to convert the following snippet to explicit sharding mode:
```python
import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh
from jax.experimental.shard import reshard, auto_axes, explicit_axes

# Two CPU devices so the mesh axis "i" has size 2 (must run before JAX initializes its backends)
jax.config.update('jax_num_cpu_devices', 2)

mesh = jax.make_mesh((2,), ("i",), axis_types=(AxisType.Explicit,))
set_mesh(mesh)

x = np.random.uniform(size=(jax.device_count() * 2, 3))
i = np.random.randint(0, x.shape[1], len(x))
j = np.random.randint(0, x.shape[1], len(x))

# Shard the leading dimension of all three arrays over mesh axis "i"
x = reshard(x, P("i"))
i = reshard(i, P("i"))
j = reshard(j, P("i"))

@jax.jit
def f1(x, i, j):
    # Gather x[:, j]: this works in explicit mode
    x_a_j = x.at[:, j].get(out_sharding=jax.typeof(i).sharding)
    print("x_a_j", jax.typeof(x_a_j), jax.typeof(x_a_j).sharding)  # printed at trace time
    # Scatter x[:, i] = x_a_j: this is the line that fails
    return x.at[:, i].set(x_a_j)

f1(x, i, j)
```
The snippet fails with the following error:
```
In [2]: f1(x,i,j)
x_a_j ShapedArray(float32[4@i,4]) NamedSharding(mesh=AbstractMesh('i': 2, axis_types=(Explicit,)), spec=PartitionSpec('i', None))
---------------------------------------------------------------------------
ShardingTypeError Traceback (most recent call last)
Cell In[2], line 1
----> 1 f1(x,i,j)
[... skipping hidden 14 frame]
Cell In[1], line 33, in f1(x, i, j)
31 x_a_j = x.at[:, j].get(out_sharding=jax.typeof(i).sharding)
32 print("x_a_j", jax.typeof(x_a_j), jax.typeof(x_a_j).sharding)
---> 33 return x.at[:, i].set(x_a_j)
ShardingTypeError: sharding rule for scatter is not implemented. Please file a bug at https://github.com/jax-ml/jax/issues. You can work around this error by dropping that operation into full auto sharding mode via: `jax.experimental.shard.auto_axes(fun, out_shardings=...)`
```
Not directly related to the bug, but relevant: I'm also trying to implement the same snippet with explicit sharding under vmap. My current attempt is below; it does not work, and it's unclear to me how to make it work:
```python
@jax.jit
@jax.vmap
def f2(x, i, j):
    x_j = x.at[j].get(out_sharding=jax.typeof(x).sharding)
    return x.at[i].set(x_j)

f2(x, i, j)
```
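One thing worth noting: under vmap the update is per row, x[r, i[r]] = x[r, j[r]] for each row r, whereas f1 overwrites whole columns (x[:, i] = x[:, j]). The sketch below uses no mesh or explicit sharding at all; it only checks that the vmapped form matches a single batched gather/scatter over row indices `jnp.arange(n)`. The names `f2_row` and `f2_batched` are hypothetical. It does not answer the explicit-sharding question, but it makes visible that vmap turns the per-row `.at[i].set` into a scatter, so in explicit mode it would likely hit the same missing scatter rule as f1 (not verified here).

```python
import jax
import jax.numpy as jnp
import numpy as np


def f2_row(x_row, i, j):
    # One row: copy element j into position i
    return x_row.at[i].set(x_row[j])


f2_vmap = jax.jit(jax.vmap(f2_row))  # maps over the leading (row) axis of x, i, j


@jax.jit
def f2_batched(x, i, j):
    # Same update without vmap: a batched gather followed by a batched scatter
    rows = jnp.arange(x.shape[0])
    return x.at[rows, i].set(x[rows, j])


x_np = np.arange(12, dtype=np.float32).reshape(4, 3)
i_np = np.array([0, 2, 1, 2], dtype=np.int32)
j_np = np.array([1, 0, 0, 2], dtype=np.int32)

out_vmap = f2_vmap(x_np, i_np, j_np)
out_batched = f2_batched(x_np, i_np, j_np)

# NumPy reference: each row is written exactly once, so the result is deterministic
rows_np = np.arange(x_np.shape[0])
expected = x_np.copy()
expected[rows_np, i_np] = x_np[rows_np, j_np]
```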
System info (python version, jaxlib version, accelerator, etc.)
```
In [3]: import jax; jax.print_environment_info()
jax: 0.6.0
jaxlib: 0.6.0
numpy: 2.2.4
python: 3.13.1 (main, Dec 19 2024, 14:22:59) [Clang 18.1.8 ]
device info: cpu-2, 2 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='mbp-10841385', release='24.4.0', version='Darwin Kernel Version 24.4.0: Wed Mar 19 21:16:34 PDT 2025; root:xnu-11417.101.15~1/RELEASE_ARM64_T6000', machine='arm64')
```