[explicit sharding] sharding rule for scatter is not implemented #28111
Labels: bug
Comments
I have a fix. I'll push it in some time :)
copybara-service bot pushed a commit that referenced this issue on Apr 18, 2025:
…n `out_sharding` argument in `set`, use the input array's `sharding`. Since this is an update, after `.set`, the input array's sharding should be preserved. Fixes #28111
PiperOrigin-RevId: 749059171
Thanks a lot for the prompt fix! Does the fix also work under vmap, as in my second snippet?
Yes!
charleshofer pushed a commit to ROCm/jax that referenced this issue on Apr 30, 2025:
…n `out_sharding` argument in `set`, use the input array's `sharding`. Since this is an update, after `.set`, the input array's sharding should be preserved. Fixes jax-ml#28111
PiperOrigin-RevId: 749089846
Description
I'm trying to convert the following snippet to explicit sharding mode:
The snippet fails with the following error:
Not related to the bug, but also relevant: I'm also trying to implement the same snippet (with explicit sharding) under vmap. My current attempt is the following, but it does not work, and it's quite unclear to me how to make it work.
System info (python version, jaxlib version, accelerator, etc.)