Operations that fail under `torch.export.export`(`torch.autograd.grad`) -> `torch.compile` · Issue #146719 · pytorch/pytorch · GitHub
Operations that fail under torch.export.export(torch.autograd.grad) -> torch.compile #146719
Open
@cw-tan

Description

🐛 Describe the bug

This is not strictly a torch.compile problem; it is more of a torch.export.export problem that appears when export is combined with torch.autograd.grad. We have seen minimal working examples where torch.export.export handles code of the form x.requires_grad_() ... torch.autograd.grad(y, [x]), but some operations are not exported correctly. Below is a minimal script that fails on the PyTorch 2.7.0 nightly for each of the operations found so far (it also includes one variant that succeeds).

import torch


class Model(torch.nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(num_channels))

    def forward(self, x, y):
        x.requires_grad_()
        # this works
        # z = (x * y * self.weight).square().sum()

        # 1. `sqrt` fails
        z = (x * y * self.weight).square().sum().sqrt()

        # 2. `rsqrt` fails
        # z = (x * y * self.weight).square().sum().rsqrt()

        # 3. `exp` fails
        # z = (x * y * self.weight).square().sum().exp()

        # 4. `torch.linalg.norm` fails
        # z = torch.linalg.norm(x * self.weight, dim=-1).sum()

        # 5. using `matmul` with broadcasting fails
        # note that this doesn't use `self.weight` and would error out anyway if we reach `loss.backward()`, but is here to show the error during `export`
        # b1n, bn1 -> b11
        # z = torch.matmul(x.unsqueeze(1), y.unsqueeze(2)).square().sum()

        # 6. using `torch.nn.functional.linear` fails
        # z = torch.nn.functional.linear(x, torch.outer(self.weight, self.weight), None).sum()

        grad = torch.autograd.grad(z, [x])[0]
        return grad


device = "cuda"
num_batch = 512
num_channels = 256

x = torch.randn(num_batch, num_channels, dtype=torch.float32, device=device)
y = torch.randn(num_batch, num_channels, dtype=torch.float32, device=device)

model = Model(num_channels).to(device=device)
eager_out = model(x, y)
print(eager_out)

batch_dim = torch.export.Dim("batch", min=1, max=1024)
exported = torch.export.export(
    model,
    (
        x,
        y,
    ),
    strict=False,
    dynamic_shapes={"x": {0: batch_dim}, "y": {0: batch_dim}},
)

model = torch.compile(exported.module())
out = model(x, y)
loss = out.square().mean()
loss.backward()

print(model.weight.grad)
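
For case 1, the gradient that the exported module is expected to reproduce has a closed form, so an eager-mode check can serve as a reference when comparing outputs. The following is only a sketch under stated assumptions: it runs on CPU with small illustrative shapes that are not taken from the script above, and `w` stands in for `self.weight`.

```python
import torch

torch.manual_seed(0)

# Small illustrative shapes (assumption; the report uses 512 x 256 on CUDA).
x = torch.randn(4, 3, requires_grad=True)
y = torch.randn(4, 3)
w = torch.randn(3)  # stand-in for self.weight

# Case 1 from the script: z = sqrt(sum((x * y * w)^2)).
z = (x * y * w).square().sum().sqrt()
(grad,) = torch.autograd.grad(z, [x])

# Closed form of the same gradient: dz/dx = x * (y * w)^2 / z.
expected = (x * (y * w).square() / z).detach()

max_abs_err = (grad - expected).abs().max().item()
print(f"max abs error vs closed form: {max_abs_err:.2e}")
```

This covers only the eager path; it says nothing about whether export or compile succeeds for the same expression.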

Error logs

The error signature looks like the following in all of these cases. My guess is that the backward pass of these problematic operations is not very export-friendly?

/home/cwtan/micromamba/envs/torch-nightly/lib/python3.11/site-packages/torch/export/_unlift.py:81: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  getattr_node = gm.graph.get_attr(lifted_node)
/home/cwtan/micromamba/envs/torch-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1790: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(
/home/cwtan/micromamba/envs/torch-nightly/lib/python3.11/site-packages/torch/export/_unlift.py:330: UserWarning: A model attribute `lifted_tensor_0` requires gradient. but it's not properly registered as a parameter. torch.export will detach it and treat it as a constant tensor but please register it as parameter instead.

...

  File "/home/cwtan/micromamba/envs/torch-nightly/lib/python3.11/site-packages/torch/_guards.py", line 1054, in detect_fake_mode
    assert fake_mode is m, (
           ^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: fake mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7fd3f1342fd0>) from tracing context 0 doesn't match mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7fd3f3d29850>) from fake tensor input 3
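
Since every failing case is reported to end in the same assertion, one quick way to triage a captured log is to pull out the FakeTensorMode addresses and confirm that two distinct mode objects are involved. This is a minimal sketch using only the standard library; the helper name `fake_mode_addresses` is hypothetical, and the message string is copied from the log above.

```python
import re

# Assertion text copied from the log in this report.
MESSAGE = (
    "AssertionError: fake mode (<torch._subclasses.fake_tensor.FakeTensorMode "
    "object at 0x7fd3f1342fd0>) from tracing context 0 doesn't match mode "
    "(<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7fd3f3d29850>) "
    "from fake tensor input 3"
)


def fake_mode_addresses(message: str) -> list[str]:
    """Return the object address of every FakeTensorMode named in `message`."""
    return re.findall(r"FakeTensorMode object at (0x[0-9a-f]+)", message)


addresses = fake_mode_addresses(MESSAGE)
print(addresses)
print("distinct modes:", len(set(addresses)) == 2)
```

Running the same helper over the logs of the other cases is a cheap way to check that they share this signature before digging further.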

Versions

Collecting environment information...
PyTorch version: 2.7.0.dev20250207+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 12.3.0-1ubuntu1~22.04) 12.3.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.11.11 | packaged by conda-forge | (main, Dec  5 2024, 14:17:24) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.6.77
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX A1000 6GB Laptop GPU
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        46 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               20
On-line CPU(s) list:                  0-19
Vendor ID:                            GenuineIntel
Model name:                           13th Gen Intel(R) Core(TM) i7-13800H
CPU family:                           6
Model:                                186
Thread(s) per core:                   2
Core(s) per socket:                   10
Socket(s):                            1
Stepping:                             2
BogoMIPS:                             5836.79
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization:                       VT-x
Hypervisor vendor:                    Microsoft
Virtualization type:                  full
L1d cache:                            480 KiB (10 instances)
L1i cache:                            320 KiB (10 instances)
L2 cache:                             12.5 MiB (10 instances)
L3 cache:                             24 MiB (1 instance)
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Mitigation; Clear Register File
Vulnerability Retbleed:               Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.2.2
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pytorch-triton==3.2.0+git4b3bb1f8
[pip3] torch==2.7.0.dev20250207+cu124
[conda] Could not collect

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4
