`vmap` not working on `torch.arange`, `torch.scalar_tensor`, and `torch.ones` · Issue #152295 · pytorch/pytorch · GitHub
vmap not working on torch.arange, torch.scalar_tensor, and torch.ones #152295
Open
@defaultd661

Description

🐛 Describe the bug

torch.arange

To Reproduce

import torch
from functools import partial


def test_bug():
    batched_arange = torch.vmap(partial(torch.arange, step=1))
    start = torch.tensor([1, 2, 3], dtype=torch.int64)
    end = torch.tensor([25, 26, 27], dtype=torch.int64)
    batched_arange(start, end)


if __name__ == '__main__':
    test_bug()

Output

RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.
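A possible workaround, not taken from the issue: `torch.arange` needs Python numbers for `start` and `end`, so a batched tensor has to be converted with `.item()`, which vmap rejects. If every row has the same length (here `end - start` is 24 for each row), the range can be built from a shared offset tensor instead. The names `length`, `offsets`, and `arange_from` are hypothetical, and the equal-length requirement is an assumption of this sketch.

```python
import torch

start = torch.tensor([1, 2, 3], dtype=torch.int64)
end = torch.tensor([25, 26, 27], dtype=torch.int64)

# Assumption: all rows share one length, so it can be read once as a
# Python int outside of vmap (no .item() call on a batched tensor).
length = int((end - start)[0])
assert bool(((end - start) == length).all()), "rows must have equal length"

# Unbatched offsets 0..length-1, captured by the closure below.
offsets = torch.arange(length, dtype=start.dtype)


def arange_from(s):
    # s is a 0-dim tensor inside vmap; broadcasting yields s, s+1, ...
    return s + offsets


batched = torch.vmap(arange_from)(start)  # shape (3, 24)
```

The same result is available without vmap as `start.unsqueeze(1) + offsets`.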

torch.ones

To Reproduce

import torch


def test_bug():
    batched_shapes = torch.tensor([[2, 3], [3, 4], [4, 5]], dtype=torch.int64)
    def ones_from_shape(shape):
        return torch.ones(shape[0], shape[1])
    batched_ones = torch.vmap(ones_from_shape)
    batched_ones(batched_shapes)


if __name__ == '__main__':
    test_bug()

Output

RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.
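Separately from the `.item()` error, this reproduction asks for a different output shape per batch element (2x3, 3x4, 4x5), which a single stacked tensor cannot hold. One hedged alternative, assuming a padded representation is acceptable, is to build a fixed-size mask of ones under vmap. The names `max_rows`, `max_cols`, and `ones_mask_from_shape` are hypothetical and not part of the issue.

```python
import torch

batched_shapes = torch.tensor([[2, 3], [3, 4], [4, 5]], dtype=torch.int64)

# Assumption: padding to the largest shape is acceptable. The maxima are
# computed outside vmap, where converting to a Python int is allowed.
max_rows = int(batched_shapes[:, 0].max())
max_cols = int(batched_shapes[:, 1].max())
row_idx = torch.arange(max_rows)
col_idx = torch.arange(max_cols)


def ones_mask_from_shape(shape):
    # shape is a length-2 tensor inside vmap; comparisons stay tensor ops.
    inside = (row_idx[:, None] < shape[0]) & (col_idx[None, :] < shape[1])
    return inside.to(torch.float32)


padded_ones = torch.vmap(ones_mask_from_shape)(batched_shapes)  # (3, 4, 5)
```

Each slice `padded_ones[i]` holds ones in its top-left `shape[0]` by `shape[1]` block and zeros elsewhere. If truly ragged outputs are needed, a plain Python loop over the shapes is the simpler option.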

torch.scalar_tensor

To Reproduce

import torch


def test_bug():
    batched_scalar = torch.vmap(torch.scalar_tensor)
    values = torch.tensor([1.0, 2.0, 3.0])
    batched_scalar(values)

if __name__ == '__main__':
    test_bug()

Output

RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.
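A possible workaround, not taken from the issue: inside vmap each element of `values` is already a 0-dim tensor, so the conversion that `torch.scalar_tensor` performs on a Python number is not needed. If the goal is only a scalar tensor of a given dtype, a dtype cast may be enough. The helper `to_scalar` and the choice of `torch.float64` are hypothetical.

```python
import torch

values = torch.tensor([1.0, 2.0, 3.0])


def to_scalar(v, dtype=torch.float64):
    # v is a 0-dim tensor inside vmap; cast it instead of rebuilding it
    # from a Python number, which would require .item().
    return v.to(dtype)


batched_scalar = torch.vmap(to_scalar)(values)  # shape (3,), dtype float64
```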

Versions

PyTorch version: 2.7.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

cc @zou3519 @Chillee @samdow @kshitij12345

Labels: module: functorch, module: vmap, triaged
