See the following MWE: calling `jnp.repeat` with an explicit mesh set leads to an error, even if I don't declare any sharding specification on the input array. The error is raised on the second `jnp.repeat` call.
I think there are two issues here. One is the missing sharding rule for scatter-add, which sounds like #28111, which I thought had been addressed. The other: shouldn't `jnp.repeat` leave its output unsharded when no sharding is requested, even if a mesh is set?
This is on jax installed from the git main branch as of ~yesterday night.
System info (python version, jaxlib version, accelerator, etc.)
…ng is provided.
In cases where `axis` is `None` or the input is sharded on the `axis` we are going to repeat on.
If the input is not sharded on the repeat axis, forward the input sharding to the output.
Fixes #28538
PiperOrigin-RevId: 756074364
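The forwarding rule in the commit message above can be modeled in a few lines of plain Python. This is a hypothetical sketch, not JAX's implementation. `forwarded_repeat_spec`, `repeat_last_axis`, and the tuple-based spec are illustrative stand-ins for `jax.sharding.PartitionSpec`. A return value of `None` only marks the cases where the input sharding is *not* forwarded (`axis` is `None`, or the input is sharded on the repeat axis). What JAX actually does in those cases is not modeled here.

```python
from typing import List, Optional, Tuple

# Hypothetical model of a partition spec: one entry per array dimension,
# holding a mesh-axis name (str) or None if that dimension is unsharded.
Spec = Tuple[Optional[str], ...]


def forwarded_repeat_spec(in_spec: Spec, axis: Optional[int]) -> Optional[Spec]:
    """Return the output spec to forward for a repeat, or None if not forwarded.

    None covers the two cases the commit message singles out: axis is None
    (the input is flattened) or the input is sharded along the repeat axis.
    """
    if axis is None:
        return None
    ndim = len(in_spec)
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for {ndim} dimensions")
    axis %= ndim  # normalize negative axes, e.g. -1 -> ndim - 1
    if in_spec[axis] is not None:
        return None
    # The repeat axis is unsharded, so every dimension keeps the input's layout.
    return in_spec


def repeat_last_axis(rows: List[List[int]], repeats: int) -> List[List[int]]:
    """Repeat each element of each row `repeats` times (like repeat on axis=1)."""
    return [[v for v in row for _ in range(repeats)] for row in rows]


if __name__ == "__main__":
    print(forwarded_repeat_spec(("x", None), 1))     # ('x', None): forwarded
    print(forwarded_repeat_spec(("x", None), 0))     # None: sharded on repeat axis
    print(forwarded_repeat_spec(("x", None), None))  # None: input is flattened

    # Intuition for why forwarding is safe: repeating row shards (axis 0)
    # independently along axis 1 and concatenating them gives the same result
    # as repeating the whole array.
    x = [[1, 2], [3, 4], [5, 6], [7, 8]]
    shards = [x[:2], x[2:]]  # rows split across two hypothetical devices
    local = [repeat_last_axis(s, 2) for s in shards]
    assert sum(local, []) == repeat_last_axis(x, 2)
```

Repeating along an unsharded axis only changes that axis's length. Each device can therefore repeat its local block without communication, which is presumably why forwarding the input sharding is sound in that case.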