Closed as duplicate of #71673
Description
🐛 Bug
import torch
x = torch.rand(2, 3, 5)
b = torch.tensor([[True, False, True], [False, True, True]]) # shape (2, 3)
x.numpy()[b.numpy(), 0] # shape (4,)
x[b, 0] # raises
IndexError: The shape of the mask [2, 3] at index 1 does not match the shape of the indexed tensor [2, 5] at index 1
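A possible workaround for reading, sketched under the assumption that the intended semantics are NumPy's. One option converts the boolean mask into explicit integer index tensors with `nonzero(as_tuple=True)`. The other selects index 0 on the last dimension first, which reduces `x` to shape (2, 3), and then applies the mask. Neither is an official fix; both are compared against the NumPy result.

```python
import torch

x = torch.rand(2, 3, 5)
b = torch.tensor([[True, False, True], [False, True, True]])  # shape (2, 3)

# NumPy reference: boolean mask over dims 0-1, integer 0 on dim 2
expected = torch.from_numpy(x.numpy()[b.numpy(), 0])

# Workaround 1: turn the mask into explicit integer index tensors (row-major order)
rows, cols = b.nonzero(as_tuple=True)
via_nonzero = x[rows, cols, 0]

# Workaround 2: select along the last dim first, then apply the (2, 3) mask
via_slice = x[:, :, 0][b]

print(via_nonzero.shape)  # torch.Size([4])
assert torch.equal(via_nonzero, expected)
assert torch.equal(via_slice, expected)
```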
Expected behavior
No error; a tensor of shape (4,) should be returned, as NumPy does.
Likewise, assignment via x[b, 0] = ... should work rather than raise. The assignment case is the one I actually ran into whilst developing code.
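For the assignment case, a minimal sketch of a workaround, assuming the goal is to write into `x[..., 0]` at the masked positions. `x[:, :, 0]` is a view into `x`, so masked assignment through it writes into `x`; the explicit-index form from `nonzero(as_tuple=True)` should give the same result.

```python
import torch

b = torch.tensor([[True, False, True], [False, True, True]])  # shape (2, 3)

# Masked assignment through a view: x[:, :, 0] shares storage with x
x = torch.zeros(2, 3, 5)
x[:, :, 0][b] = 1.0

# Equivalent form using explicit index tensors derived from the mask
y = torch.zeros(2, 3, 5)
rows, cols = b.nonzero(as_tuple=True)
y[rows, cols, 0] = 1.0

assert torch.equal(x, y)
print(int(x.sum()))  # 4
```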
Environment
- PyTorch Version (e.g., 1.0): 1.9.0
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, source): conda
- Python version: 3.9
- Any other relevant information: NumPy indexing tested with NumPy version 1.20.3