[sharding-in-types] Bug in jnp.repeat · Issue #28538 · jax-ml/jax

Closed
PhilipVinc opened this issue May 6, 2025 · 1 comment · Fixed by #28600

@PhilipVinc (Contributor)

Description

See the following MWE: calling jnp.repeat with an explicit mesh set, even without declaring any sharding specification on the input array, leads to an error:

import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh
from jax.experimental.shard import reshard, auto_axes, explicit_axes

jax.config.update("jax_num_cpu_devices", 4)

mesh = jax.make_mesh((4,), ("S",), axis_types=(AxisType.Explicit,))

print(f"Current mesh is: {get_abstract_mesh()}")

a = jnp.eye(3)
# Works: no explicit mesh has been set yet.
jnp.repeat(a, np.array((2, 2, 2)) - 1, axis=0)

set_mesh(mesh)
print(f"Current mesh is: {get_abstract_mesh()}")
# Fails once the explicit mesh is set.
jnp.repeat(a, np.array((2, 2, 2)) - 1, axis=0)

errors on the second jnp.repeat call with

ShardingTypeError: sharding rule for scatter-add is not implemented. Please file a bug at https://github.com/jax-ml/jax/issues. You can work around this error by dropping that operation into full auto sharding mode via: `jax.experimental.shard.auto_axes(fun, out_shardings=...)`

I think there are two issues here. One is the missing sharding rule for scatter-add, which sounds like #28111, which I thought had been addressed. The other is that jnp.repeat should not shard the output when no sharding is requested, even if a mesh is set.
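For context, jnp.repeat with an array of repeats follows np.repeat semantics, where repeats[i] gives the count for row i along the chosen axis. A small NumPy check (an illustrative aside, not part of the original report) shows that the MWE's repeats of (2, 2, 2) - 1 leave the array unchanged:

```python
import numpy as np

a = np.eye(3)
repeats = np.array((2, 2, 2)) - 1  # elementwise: (1, 1, 1)
# Row i of `a` is repeated repeats[i] times along axis 0,
# so with all-ones repeats the output equals the input.
out = np.repeat(a, repeats, axis=0)
```

The failure is therefore not about an exotic shape: even this identity-like repeat hits the missing sharding rule.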

This is on jax installed from the git main branch as of roughly yesterday night.

System info (python version, jaxlib version, accelerator, etc.)

In [3]: import jax; jax.print_environment_info()
jax:    0.6.1.dev20250505+13ca7002b
jaxlib: 0.6.0
numpy:  2.2.5
python: 3.13.3 (main, Apr  9 2025, 03:47:57) [Clang 20.1.0 ]
device info: cpu-4, 4 local devices
process_count: 1
platform: uname_result(system='Darwin', node='mbp-10841385', release='24.4.0', version='Darwin Kernel Version 24.4.0: Wed Mar 19 21:16:34 PDT 2025; root:xnu-11417.101.15~1/RELEASE_ARM64_T6000', machine='arm64')
@PhilipVinc PhilipVinc added the bug Something isn't working label May 6, 2025
@yashk2810 yashk2810 self-assigned this May 6, 2025
@yashk2810 (Collaborator) commented May 7, 2025

I am working on this, but until then you can work around it by doing the following:

jax.experimental.shard.auto_axes(lambda x: jnp.repeat(x, np.array((2,2,2,)) - 1, axis=0), out_sharding=...)(jnp.eye(3))

copybara-service bot pushed a commit that referenced this issue May 8, 2025
…ng is provided.

In cases where axis is None or the input is sharded on the `axis` we are going to repeat on.

If the input is not sharded on the repeat axis, forward the input sharding to the output.

Fixes #28538

PiperOrigin-RevId: 756074364
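The forwarding rule described in the commit message can be sketched in plain Python. Here a PartitionSpec is modeled as a tuple with one mesh-axis name (or None) per array dimension, and `repeat_sharding_rule` is a hypothetical illustration of the logic, not JAX's actual implementation:

```python
def repeat_sharding_rule(in_spec, axis):
    """Sketch: in_spec is a tuple of mesh-axis names (or None),
    one entry per array dimension; axis is the repeat axis."""
    if axis is None or in_spec[axis] is not None:
        # Flattening repeat, or input sharded on the repeat axis:
        # no simple rule, so fall back to full auto sharding mode.
        raise NotImplementedError("drop into auto_axes")
    # Input unsharded on the repeat axis: forward its sharding.
    return in_spec

# The MWE's input is fully replicated, so the output stays replicated:
print(repeat_sharding_rule((None, None), axis=0))  # (None, None)
```

Under this rule the MWE's call succeeds without any scatter-add sharding rule, since the replicated input sharding is simply forwarded to the output.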