Description
🐛 Describe the bug
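Similar to #154881: calling topk on a 5-D tensor on the MPS device aborts the whole process with a Metal Performance Shaders assertion instead of raising a Python error.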
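import torch
# 5-D tensor on the MPS backend
x = torch.ones([5, 4, 3, 2, 1], device="mps")
# topk along dim 0 hits the MPSNDArraySort assertion below (reported as Axis = 4)
torch.ops.aten.topk(x, k=5, dim=0)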
Error message:
/AppleInternal/Library/BuildRoots/01adf19d-fba1-11ef-a947-f2a857e00a32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSNDArray/Kernels/MPSNDArraySort.mm:208: failed assertion `(null)" Axis = 4. This class only supports axis = 0, 1, 2, 3
'
zsh: abort
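Until this is fixed, one possible workaround is to run topk on a 2-D view. The idea is to move the target dim to the end, flatten the remaining dims into a single batch dim, call torch.topk along the last dim, and then restore the original shape. This is a sketch only and has not been verified on MPS hardware. The helper name topk_via_2d is hypothetical. The approach also assumes that the MPS sort kernel accepts 2-D inputs, which the assertion's list of supported axes (0, 1, 2, 3) suggests but does not guarantee. On machines without MPS, the example falls back to CPU.

```python
import torch


def topk_via_2d(x: torch.Tensor, k: int, dim: int = -1, largest: bool = True, sorted: bool = True):
    """Hypothetical workaround: compute topk along `dim` on a 2-D view so the
    sort kernel never sees a tensor with more than 2 dimensions."""
    dim = dim % x.dim()
    # Move the reduction dim to the end and flatten all other dims into one batch dim.
    moved = x.movedim(dim, -1)
    batch_shape = moved.shape[:-1]
    flat = moved.reshape(-1, moved.shape[-1])
    values, indices = torch.topk(flat, k, dim=-1, largest=largest, sorted=sorted)
    # Restore the layout: unflatten the batch dims, then move the k-sized dim back to `dim`.
    values = values.reshape(*batch_shape, k).movedim(-1, dim)
    indices = indices.reshape(*batch_shape, k).movedim(-1, dim)
    return values, indices


device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.ones([5, 4, 3, 2, 1], device=device)
values, indices = topk_via_2d(x, k=5, dim=0)
print(values.shape)  # torch.Size([5, 4, 3, 2, 1])
```

With distinct input values, the result should match torch.topk on CPU index for index. With ties, as in the all-ones repro above, the returned indices may legitimately differ from a direct call.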
Versions
torch==2.7.0