torch.nn.functional.one_hot has inconsistent behavior between eager and torch.compile when num_classes=0 · Issue #146274 · pytorch/pytorch · GitHub
@meetmul

Description


🐛 Describe the bug

When num_classes=0, torch.nn.functional.one_hot raises "Class values must be smaller than num_classes." in eager mode. Under torch.compile it raises nothing and returns an empty tensor instead, of size (5, 0) for the 5-element input below.

import torch

f = torch.nn.functional.one_hot
a = torch.arange(0, 5) % 3  # tensor([0, 1, 2, 0, 1])
num_classes = 0

# Eager mode: raises, because every class value is >= num_classes
try:
    f(a, num_classes)
except Exception as e:
    print("Error on eager: ", str(e))

# torch.compile: returns an empty (5, 0) tensor instead of raising
res = torch.compile(f)(a, num_classes)
print("Output under torch.compile: ", res)

Error logs

Error on eager: Class values must be smaller than num_classes.
Output under torch.compile: tensor([], size=(5, 0), dtype=torch.int64)
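The eager error follows from the input contract of one_hot: with an explicit num_classes, every class value must lie in [0, num_classes). With num_classes=0, even the value 0 violates that. The compiled result of size (5, 0) is what you would get if the range check were skipped and only the output shape (N, num_classes) were computed.

Below is a minimal pure-Python sketch of that contract, not PyTorch code. `one_hot_reference` and its `check` flag are hypothetical names. The upper-bound error message is copied from the log above, while the non-negativity check is an assumption. The `num_classes=-1` (infer from data) path is omitted.

```python
def one_hot_reference(values, num_classes, check=True):
    """Pure-Python sketch of the one_hot contract for an explicit num_classes.

    Hypothetical helper, not a PyTorch API. Assumes eager one_hot rejects any
    class value outside [0, num_classes). check=False is a hypothetical switch
    that skips validation, mimicking the compiled output reported above.
    """
    if check and values:
        # Assumed lower-bound check (message is illustrative)
        if min(values) < 0:
            raise RuntimeError("Class values must be non-negative.")
        # Upper-bound check, message taken from the eager error log
        if max(values) >= num_classes:
            raise RuntimeError("Class values must be smaller than num_classes.")
    # One row per input value, one column per class: 1 where the column index
    # equals the value, 0 elsewhere. With num_classes=0 every row is empty.
    return [[int(c == v) for c in range(num_classes)] for v in values]


a = [i % 3 for i in range(5)]  # [0, 1, 2, 0, 1], same values as the repro
try:
    one_hot_reference(a, 0)
except RuntimeError as e:
    print("checked:", e)  # checked: Class values must be smaller than num_classes.
print("unchecked:", one_hot_reference(a, 0, check=False))  # unchecked: [[], [], [], [], []]
print(one_hot_reference(a, 3))
```

In this sketch, the unchecked branch produces five empty rows, which is the list analogue of a (5, 0) tensor. That shape can be computed without reading the values, while the range check needs the data itself. This is one plausible reading of the discrepancy, not a confirmed root cause.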

Versions

[pip3] numpy==1.26.2
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] optree==0.13.1
[pip3] torch==2.5.1
[pip3] triton==3.1.0
[conda] numpy 1.26.2 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.2.1.3 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.5.147 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.6.1.9 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.3.1.170 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.21.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.4.127 pypi_0 pypi
[conda] optree 0.13.1 pypi_0 pypi
[conda] torch 2.5.1 pypi_0 pypi
[conda] triton 3.1.0 pypi_0 pypi

cc @chauhang @penguinwu @eellison @zou3519 @bdhirsh @yf225

Labels: actionable, module: fakeTensor, module: pt2-dispatcher, oncall: pt2, triaged