🐛 Describe the bug
torch.vmap raises a RuntimeError complaining about .item() when the mapped function is a tensor factory whose output shape or values are determined by its tensor arguments. The failure reproduces with torch.arange, torch.ones, and torch.scalar_tensor.

torch.arange
To Reproduce
import torch
from functools import partial

def test_bug():
    batched_arange = torch.vmap(partial(torch.arange, step=1))
    start = torch.tensor([1, 2, 3], dtype=torch.int64)
    end = torch.tensor([25, 26, 27], dtype=torch.int64)
    batched_arange(start, end)

if __name__ == '__main__':
    test_bug()
Output
RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.
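For completeness, a possible workaround when every range in the batch has the same length (as in the repro above, where each is 24 elements): the batched ranges can be built with broadcasting instead of vmap. This is a hedged sketch, not a fix for the underlying limitation; the equal-length assumption is essential, since ragged ranges cannot form one rectangular tensor anyway.

```python
import torch

start = torch.tensor([1, 2, 3], dtype=torch.int64)
end = torch.tensor([25, 26, 27], dtype=torch.int64)

# Assumes all ranges share one length; here (end - start) is 24 for every row.
length = int((end - start)[0])

# Broadcasting: each row i is start[i] + [0, 1, ..., length-1],
# i.e. the same values torch.arange(start[i], end[i]) would produce.
batched = start.unsqueeze(1) + torch.arange(length)
```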
torch.ones
To Reproduce
import torch

def test_bug():
    batched_shapes = torch.tensor([[2, 3], [3, 4], [4, 5]], dtype=torch.int64)

    def ones_from_shape(shape):
        return torch.ones(shape[0], shape[1])

    batched_ones = torch.vmap(ones_from_shape)
    batched_ones(batched_shapes)

if __name__ == '__main__':
    test_bug()
Output
RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.
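Note that this case is arguably unrepresentable under vmap regardless of the error message: the three shapes differ, so the batched result would be ragged, while vmap requires every batch element to produce the same shape. As a hedged sketch, a plain loop returning a list of tensors is one way to express the intent:

```python
import torch

batched_shapes = torch.tensor([[2, 3], [3, 4], [4, 5]], dtype=torch.int64)

# The per-element outputs have different shapes, so they cannot be stacked
# into one tensor; collect them in a Python list instead of using vmap.
ragged_ones = [torch.ones(int(h), int(w)) for h, w in batched_shapes.tolist()]
```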
torch.scalar_tensor
To Reproduce
import torch

def test_bug():
    batched_scalar = torch.vmap(torch.scalar_tensor)
    values = torch.tensor([1.0, 2.0, 3.0])
    batched_scalar(values)

if __name__ == '__main__':
    test_bug()
Output
RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.
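For this case a workaround is straightforward, since a batch of 0-d scalar tensors stacked along the vmapped dimension is just the original 1-D tensor. A hedged sketch:

```python
import torch

values = torch.tensor([1.0, 2.0, 3.0])

# What vmap(torch.scalar_tensor)(values) would conceptually return is the
# stacked 1-D tensor itself, so no vmap is needed.
batched = values.clone()

# If individual 0-d scalar tensors are needed, unbind along dim 0.
scalars = list(values.unbind(0))
```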
Versions
PyTorch version: 2.7.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A