Description
MWE below:
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, AxisType
from jax.experimental.shard import reshard, explicit_axes
# Setup: 2 CPU devices
jax.config.update("jax_num_cpu_devices", 2)
devices = np.array(jax.devices())
mesh = jax.make_mesh((2,),("s",), axis_types=(AxisType.Explicit,),)
jax.sharding.set_mesh(mesh) # Set this as the default mesh for jax.
def simple_func(w, x):
    return jnp.sum(w * x, axis=-1)
# Make inputs
w = jnp.array([1.0, 2.0, 3.0, 4.0])
x = jnp.ones((5, 2, 4))
# Setup sharding
sharded_x_r = reshard(x, P(None, "s", None))
jax.lax.map(lambda _x: simple_func(w, _x), sharded_x_r, batch_size=2)
This crashes with:
ShardingTypeError: Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for the gather indexing.
The problem is in this line
jax/jax/_src/lax/control_flow/loops.py
Line 2519 in 2193c59
which should use the `.at[].get(...)` notation.
cc @yashk2810, who recently fixed related issues in `jax.lax.map`. This issue is similar to #29164 (reply in thread) but emerges when the remainder is not None.
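To make the "remainder is not None" condition concrete, here is a rough sketch of how I understand the `batch_size` split to work: the leading axis is cut into full batches, and whatever is left over is handled separately. The helper name `split_batches_and_remainder` is hypothetical, and this is plain NumPy for illustration only, not the actual JAX implementation.

```python
import numpy as np


def split_batches_and_remainder(x, batch_size):
    """Hypothetical helper: split the leading axis into full batches plus a leftover slice."""
    num_batches, rem = divmod(x.shape[0], batch_size)
    full = num_batches * batch_size
    # Full batches get an extra leading axis: (num_batches, batch_size, ...).
    batched = x[:full].reshape(num_batches, batch_size, *x.shape[1:])
    # The leftover rows only exist when batch_size does not divide the leading axis.
    remainder = x[full:] if rem else None
    return batched, remainder


# Same shape as the MWE: leading axis 5 with batch_size=2 leaves one row over.
x = np.ones((5, 2, 4))
batched, remainder = split_batches_and_remainder(x, batch_size=2)
print(batched.shape)    # (2, 2, 2, 4)
print(remainder.shape)  # (1, 2, 4)

# With a leading axis of 4 there is no remainder.
even_batched, even_remainder = split_batches_and_remainder(np.ones((4, 2, 4)), batch_size=2)
print(even_remainder)   # None
```

In the MWE the leading axis is 5 and `batch_size=2`, so there is a non-empty remainder, which would be consistent with the crash only showing up in that case.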
System info (python version, jaxlib version, accelerator, etc.)
Current master as of the morning of 3 June 2025:
In [1]: import jax; jax.print_environment_info()
...:
jax: 0.6.2.dev20250603+2193c59fb
jaxlib: 0.6.1
numpy: 2.2.6
python: 3.13.3 (main, Apr 9 2025, 03:47:57) [Clang 20.1.0 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='mbp-10841385', release='24.5.0', version='Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:49 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T6000', machine='arm64')