🐛 Describe the bug
I was trying to use the transformer described in this tutorial with torch.compile
and nested tensors, but the backward pass fails with an AssertionError.
Here is a minimal example that triggers the same error:
import torch
from torch import nn
import torch.nn.functional as F

DEVICE = "cuda"


class DummyTransformer(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, x):
        x = self.linear(x)
        x = x.unflatten(-1, [1, self.d_model]).transpose(1, 2)
        x = F.scaled_dot_product_attention(x, x, x)
        return x


x = torch.nested.nested_tensor_from_jagged(
    torch.randn(11, 3, device=DEVICE),
    offsets=torch.tensor([0, 5, 11], dtype=torch.int64, device=DEVICE),
)
layer = DummyTransformer(d_model=3).to(DEVICE)
layer = torch.compile(layer)
y = layer(x)
y.sum().backward()
Traceback (most recent call last):
  File "/scratch/big/home/piegue/chan_inv_clf/report_issue_compile_nested_transformer.py", line 34, in <module>
    y.sum().backward()
  File "/home/piegue/miniconda3/envs/torch/lib/python3.11/site-packages/torch/_tensor.py", line 648, in backward
    torch.autograd.backward(
  File "/home/piegue/miniconda3/envs/torch/lib/python3.11/site-packages/torch/autograd/__init__.py", line 354, in backward
    _engine_run_backward(
  File "/home/piegue/miniconda3/envs/torch/lib/python3.11/site-packages/torch/autograd/graph.py", line 829, in _engine_run_backward
    return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/piegue/miniconda3/envs/torch/lib/python3.11/site-packages/torch/autograd/function.py", line 311, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/piegue/miniconda3/envs/torch/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2157, in backward
    all_args = _backward_prologue_functional(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/piegue/miniconda3/envs/torch/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1665, in _backward_prologue_functional
    flat_processed_tangents = list(
                              ^^^^^
  File "/home/piegue/miniconda3/envs/torch/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1668, in <genexpr>
    AOTDispatchAutograd.process_runtime_tangent(
  File "/home/piegue/miniconda3/envs/torch/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1930, in process_runtime_tangent
    assert len(meta.attrs) == len(runtime_subclass_keys)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
Versions
It failed with torch version 2.7.1:
PyTorch version: 2.7.1+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3
Libc version: glibc-2.31
Python version: 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-190-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 10.1.243
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: Quadro RTX 6000
GPU 1: Quadro RTX 6000
GPU 2: Quadro RTX 6000
GPU 3: Quadro RTX 6000
GPU 4: Quadro RTX 6000
GPU 5: Quadro RTX 6000
GPU 6: Quadro RTX 6000
GPU 7: Quadro RTX 6000
GPU 8: Quadro RTX 6000
GPU 9: Quadro RTX 6000
Nvidia driver version: 535.54.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.4.dpkg-new
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.4.dpkg-new
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.4.dpkg-new
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.4.dpkg-new
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.4.dpkg-new
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.4.dpkg-new
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn.so.8
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_adv_train.so.8
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_ops_train.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_train.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_train.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn_adv_train.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn_ops_train.so.8
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 64
On-line CPU(s) list: 0-63
Thread(s) per core: 2
Core(s) per socket: 16
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Gold 5218 CPU @ 2.30GHz
Stepping: 7
CPU MHz: 1094.088
CPU max MHz: 3900.0000
CPU min MHz: 1000.0000
BogoMIPS: 4600.00
Virtualization: VT-x
L1d cache: 1 MiB
L1i cache: 1 MiB
L2 cache: 32 MiB
L3 cache: 44 MiB
NUMA node0 CPU(s): 0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40,42,44,46,48,50,52,54,56,58,60,62
NUMA node1 CPU(s): 1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31,33,35,37,39,41,43,45,47,49,51,53,55,57,59,61,63
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit: KVM: Mitigation: Split huge pages
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI Vulnerable, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] pytorch-lightning==2.2.2
[pip3] pytorch-metric-learning==2.8.1
[pip3] torch==2.7.1
[pip3] torchaudio==2.7.1
[pip3] torchdata==0.7.1
[pip3] torchinfo==1.8.0
[pip3] torchmetrics==1.3.2
[pip3] torchtext==0.17.1
[pip3] triton==3.3.1
[conda] numpy 1.26.4 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.6.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.6.80 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.5.1.17 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.0.4 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.7.77 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.1.2 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.4.2 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.6.3 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.26.2 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.6.85 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.6.77 pypi_0 pypi
[conda] pytorch-lightning 2.2.2 pypi_0 pypi
[conda] pytorch-metric-learning 2.8.1 pypi_0 pypi
[conda] torch 2.7.1 pypi_0 pypi
[conda] torchaudio 2.7.1 pypi_0 pypi
[conda] torchdata 0.7.1 pypi_0 pypi
[conda] torchinfo 1.8.0 pypi_0 pypi
[conda] torchmetrics 1.3.2 pypi_0 pypi
[conda] torchtext 0.17.1 pypi_0 pypi
[conda] triton 3.3.1 pypi_0 pypi
I also tried the latest nightly version, but it led to the same result:
PyTorch version: 2.8.0.dev20250607+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
[...]
Versions of relevant libraries:
[pip3] numpy==2.1.2
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.5
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] pytorch-triton==3.3.1+gitc8757738
[pip3] torch==2.8.0.dev20250607+cu126
[pip3] torchaudio==2.8.0.dev20250607+cu126
[pip3] torchvision==0.23.0.dev20250607+cu126
[conda] numpy 2.1.2 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.6.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.6.80 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.5.1.17 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.0.4 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.7.77 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.1.2 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.4.2 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.6.3 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.26.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.6.85 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.6.77 pypi_0 pypi
[conda] pytorch-triton 3.3.1+gitc8757738 pypi_0 pypi
[conda] torch 2.8.0.dev20250607+cu126 pypi_0 pypi
[conda] torchaudio 2.8.0.dev20250607+cu126 pypi_0 pypi
[conda] torchvision 0.23.0.dev20250607+cu126 pypi_0 pypi
cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ @chauhang @penguinwu