[sharding-in-types] `jax.lax.map(...batch_size=)` bug when using Explicit shapes · Issue #29195 · jax-ml/jax · GitHub
[sharding-in-types] jax.lax.map(...batch_size=) bug when using Explicit shapes #29195
Closed
@PhilipVinc

Description

MWE below:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, AxisType
from jax.experimental.shard import reshard, explicit_axes

# Setup: 2 CPU devices
jax.config.update("jax_num_cpu_devices", 2)
devices = np.array(jax.devices())
mesh = jax.make_mesh((2,),("s",), axis_types=(AxisType.Explicit,),)
jax.sharding.set_mesh(mesh) # Set this as the default mesh for jax.

def simple_func(w, x):
    return jnp.sum(w * x, axis=-1)

# Make inputs
w = jnp.array([1.0, 2.0, 3.0, 4.0])
x = jnp.ones((5, 2, 4))

# Setup sharding
sharded_x_r = reshard(x, P(None, "s", None))

jax.lax.map(lambda _x: simple_func(w, _x), sharded_x_r, batch_size=2)

This crashes with:

ShardingTypeError: Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for the gather indexing.

The 'issue' is in this line:

scan_leaves.append(leaf[:total_batch_elems].reshape(

which should use the `.at[...].get(out_sharding=...)` notation named in the error message instead of plain slicing.
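For readers who have not looked at the batching code, here is a minimal sketch of the kind of batch/remainder split the quoted line appears to belong to. It is not the actual JAX implementation: `split_batches` is a hypothetical helper, and it uses plain unsharded arrays, so it only illustrates why the MWE (leading axis of 5, `batch_size=2`) ends up with a non-empty remainder and therefore reaches the slicing step.

```python
import jax.numpy as jnp


def split_batches(leaf, batch_size):
    """Hypothetical helper: split the leading axis into full batches plus a remainder."""
    num_batches, remainder = divmod(leaf.shape[0], batch_size)
    total_batch_elems = num_batches * batch_size
    # Leading elements that fill whole batches, reshaped to
    # (num_batches, batch_size, ...). This mirrors the slice-then-reshape
    # pattern of the line quoted above.
    batched = leaf[:total_batch_elems].reshape(
        num_batches, batch_size, *leaf.shape[1:]
    )
    # Leftover elements that do not fill a whole batch, or None if there are none.
    rest = leaf[total_batch_elems:] if remainder else None
    return batched, rest


# Same shapes as the MWE: 5 elements along the mapped axis, batch_size=2.
x = jnp.ones((5, 2, 4))
batched, rest = split_batches(x, batch_size=2)
print(batched.shape)  # two full batches of two elements each
print(rest.shape)     # one leftover element
```

With a leading axis of 4 instead of 5 the remainder would be `None`, which matches the observation that the crash only shows up when a remainder exists.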

cc @yashk2810 who recently fixed related issues in jax.lax.map. This issue is similar to #29164 (reply in thread) but emerges if the remainder is not None.

System info (python version, jaxlib version, accelerator, etc.)

current master as of 3 June 2025 morning

In [1]: import jax; jax.print_environment_info()
   ...:
jax:    0.6.2.dev20250603+2193c59fb
jaxlib: 0.6.1
numpy:  2.2.6
python: 3.13.3 (main, Apr  9 2025, 03:47:57) [Clang 20.1.0 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='mbp-10841385', release='24.5.0', version='Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:49 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T6000', machine='arm64')

Metadata

Labels: bug (Something isn't working)