🐛 Describe the bug
Hi there,
I’m reporting an issue where torch.export.export() fails on GPU when exporting a model that contains an nn.LSTM layer. The same model exports successfully on CPU, which suggests a GPU-specific problem in tracing or fake tensor handling. The error message refers to custom kernels, which seems unrelated to this model, since it uses only built-in modules (nn.LSTM and nn.Linear).
Minimal Reproduction Script
import torch
import torch.nn as nn


class SimpleLSTMModel(nn.Module):
    def __init__(self, input_size=10, hidden_size=20, num_layers=1):
        super(SimpleLSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # x: (seq_len, batch, input_size), since batch_first=False
        out, _ = self.lstm(x)
        out = self.linear(out)
        return out


if __name__ == "__main__":
    model = SimpleLSTMModel()
    example_input = (torch.randn(5, 3, 10),)

    print("Attempting export on CPU...")
    try:
        model_cpu = model.to("cpu")
        exported_cpu = torch.export.export(model_cpu, example_input)
        print("CPU export succeeded.")
    except Exception as e:
        print("CPU export failed:", e)

    if torch.cuda.is_available():
        print("Attempting export on GPU...")
        try:
            # Module.to() moves parameters in place and returns the same module
            model_gpu = model.to("cuda")
            input_gpu = (example_input[0].to("cuda"),)
            exported_gpu = torch.export.export(model_gpu, input_gpu)
            print("GPU export succeeded.")
        except Exception as e:
            print("GPU export failed:", e)
    else:
        print("GPU not available. Skipping GPU export.")
Output
Attempting export on CPU...
CPU export succeeded.
Attempting export on GPU...
GPU export failed: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
The error is likely caused by how torch.export interacts with the GPU-backed LSTM internals. It's possible that torch.export does not properly handle fused GPU kernels or cuDNN ops when fake tensors are involved during graph capture.
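A minimal diagnostic sketch (not a fix) for testing the cuDNN hypothesis: it repeats the GPU export with cuDNN disabled via the `torch.backends.cudnn.flags(enabled=False)` context manager and, on failure, reports the innermost Python frame of the traceback, which should point at the code that tried to read the data pointer. The helper `try_export` and its `disable_cudnn` parameter are hypothetical names introduced here. Whether disabling cuDNN actually lets the export succeed has not been verified; if it does, the captured graph may not use the cuDNN path, so it would be a workaround at best.

```python
import contextlib
import traceback

import torch
import torch.nn as nn


class SimpleLSTMModel(nn.Module):
    # Same model as in the reproduction script.
    def __init__(self, input_size=10, hidden_size=20, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, x):
        out, _ = self.lstm(x)
        return self.linear(out)


def try_export(model, inputs, disable_cudnn=False):
    """Hypothetical helper: attempt one export, return (succeeded, message).

    On failure, the message names the innermost Python frame of the
    traceback, i.e. the last Python code that ran before the error.
    """
    # Optionally turn cuDNN off for the duration of the export only.
    ctx = (torch.backends.cudnn.flags(enabled=False)
           if disable_cudnn else contextlib.nullcontext())
    try:
        with ctx:
            torch.export.export(model, inputs)
        return True, None
    except Exception as e:
        frame = traceback.extract_tb(e.__traceback__)[-1]
        where = f"{frame.filename}:{frame.lineno} in {frame.name}"
        return False, f"{type(e).__name__} at {where}: {e}"


if __name__ == "__main__":
    example = torch.randn(5, 3, 10)
    print("CPU:", try_export(SimpleLSTMModel(), (example,)))
    if torch.cuda.is_available():
        model_gpu = SimpleLSTMModel().to("cuda")
        inputs_gpu = (example.to("cuda"),)
        for flag in (False, True):
            result = try_export(model_gpu, inputs_gpu, disable_cudnn=flag)
            print(f"GPU, cuDNN disabled={flag}:", result)
    else:
        print("GPU not available. Skipping GPU export.")
```

If the GPU export fails only when cuDNN is enabled, that would support the hypothesis that the cuDNN-specific LSTM path is what touches real data pointers during tracing.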
Versions
PyTorch version: 2.6.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 12.3.0-1ubuntu1~22.04) 12.3.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-59-generic-x86_64-with-glibc2.35
Is CUDA available: True
cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4