🐛 Describe the bug
Not completely a `torch.compile` problem, more of a `torch.export.export` problem when used with `torch.autograd.grad`. We have seen minimal working examples of `torch.export.export` working for code that includes `x.requires_grad_()` ... `torch.autograd.grad(y, [x])`, but some operations don't get exported correctly. Here is a minimal script that fails on PyTorch nightly 2.7.0 for several operations found so far (and includes a minimal variant that succeeds).
```python
import torch


class Model(torch.nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(num_channels))

    def forward(self, x, y):
        x.requires_grad_()
        # this works
        # z = (x * y * self.weight).square().sum()
        # 1. `sqrt` fails
        z = (x * y * self.weight).square().sum().sqrt()
        # 2. `rsqrt` fails
        # z = (x * y * self.weight).square().sum().rsqrt()
        # 3. `exp` fails
        # z = (x * y * self.weight).square().sum().exp()
        # 4. `torch.linalg.norm` fails
        # z = torch.linalg.norm(x * self.weight, dim=-1).sum()
        # 5. using `matmul` with broadcasting fails
        # note that this doesn't use `self.weight` and would error out anyway
        # if we reached `loss.backward()`, but it is here to show the error
        # during `export`
        # b1n, bn1 -> b11
        # z = torch.matmul(x.unsqueeze(1), y.unsqueeze(2)).square().sum()
        # 6. using `torch.nn.functional.linear` fails
        # z = torch.nn.functional.linear(x, torch.outer(self.weight, self.weight), None).sum()
        grad = torch.autograd.grad(z, [x])[0]
        return grad


device = "cuda"
num_batch = 512
num_channels = 256
x = torch.randn(num_batch, num_channels, dtype=torch.float32, device=device)
y = torch.randn(num_batch, num_channels, dtype=torch.float32, device=device)
model = Model(num_channels).to(device=device)

eager_out = model(x, y)
print(eager_out)

batch_dim = torch.export.Dim("batch", min=1, max=1024)
exported = torch.export.export(
    model,
    (x, y),
    strict=False,
    dynamic_shapes={"x": {0: batch_dim}, "y": {0: batch_dim}},
)

model = torch.compile(exported.module())
out = model(x, y)
loss = out.square().mean()
loss.backward()
print(model.weight.grad)
```
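For reference, the `lifted_tensor_0` mentioned in the warnings below can be inspected on the `ExportedProgram` before `torch.compile` is ever involved. A small diagnostic sketch, run right after the `torch.export.export` call above (the `graph_module` and `constants` attributes are assumed to behave as on current nightlies):

```python
# Diagnostic sketch: look for the constant tensor that export lifted out of
# the backward computation. Uses `exported` from the repro above;
# `exported.constants` is assumed to exist on this nightly build.
exported.graph_module.print_readable()  # failing ops show a lifted_tensor_0 get_attr

for name, value in exported.constants.items():
    if isinstance(value, torch.Tensor):
        print(name, value.shape, value.dtype, value.requires_grad)
```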
Error logs
The error signature for all these cases looks like the following. I'm guessing the backward passes of these problematic operations are not very `export`-friendly?
```
/home/cwtan/micromamba/envs/torch-nightly/lib/python3.11/site-packages/torch/export/_unlift.py:81: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
getattr_node = gm.graph.get_attr(lifted_node)
/home/cwtan/micromamba/envs/torch-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1790: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(
/home/cwtan/micromamba/envs/torch-nightly/lib/python3.11/site-packages/torch/export/_unlift.py:330: UserWarning: A model attribute `lifted_tensor_0` requires gradient. but it's not properly registered as a parameter. torch.export will detach it and treat it as a constant tensor but please register it as parameter instead.
...
File "/home/cwtan/micromamba/envs/torch-nightly/lib/python3.11/site-packages/torch/_guards.py", line 1054, in detect_fake_mode
assert fake_mode is m, (
^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: fake mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7fd3f1342fd0>) from tracing context 0 doesn't match mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7fd3f3d29850>) from fake tensor input 3
```
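To narrow down where things break: the `UserWarning`s above come from the `exported.module()` unlift step, while the hard `AssertionError` only fires once `torch.compile` runs Inductor over the unlifted module. A sketch of the two-step check, following the repro script above:

```python
# Sketch: separate the unlift step from compilation (uses the repro's
# `exported`, `x`, `y`, and `eager_out`).
unlifted = exported.module()             # emits the _unlift warnings
check = unlifted(x, y)                   # eager run of the exported module
print(torch.allclose(check, eager_out))  # presumably still matches eager

compiled = torch.compile(unlifted)
compiled(x, y)                           # the fake-mode assertion fires here
```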
Versions
```
Collecting environment information...
PyTorch version: 2.7.0.dev20250207+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 12.3.0-1ubuntu1~22.04) 12.3.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.11.11 | packaged by conda-forge | (main, Dec 5 2024, 14:17:24) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.6.77
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX A1000 6GB Laptop GPU
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 20
On-line CPU(s) list: 0-19
Vendor ID: GenuineIntel
Model name: 13th Gen Intel(R) Core(TM) i7-13800H
CPU family: 6
Model: 186
Thread(s) per core: 2
Core(s) per socket: 10
Socket(s): 1
Stepping: 2
BogoMIPS: 5836.79
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 480 KiB (10 instances)
L1i cache: 320 KiB (10 instances)
L2 cache: 12.5 MiB (10 instances)
L3 cache: 24 MiB (1 instance)
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Mitigation; Clear Register File
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.2.2
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pytorch-triton==3.2.0+git4b3bb1f8
[pip3] torch==2.7.0.dev20250207+cu124
[conda] Could not collect
```
cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4