In a multithreading scenario, the dynamic library obtained from aot_compile encounters a CUDA illegal memory access error during inference. #132475
Comments
Can you share standalone repro code? That would help us debug. One thing to note is that the AOTI runtime does its own thread pool management, which might interfere with your thread pool; see a similar discussion (for AOTI on CPU) at #134946.
Sorry, I was not able to reproduce it in a standalone multithreaded environment; it only appears in the TensorFlow 1.x environment. However, I compiled a PyTorch 2.3.0 package with the DEBUG option and got a more complete stack trace. Compared to the previous stack, the error appears in memcpy_and_sync.
The cause of the IMA (illegal memory access) error is that the dynamic library generated by torch._export.aot_compile accepts at::Tensor inputs in C++ but does not check their shape or dtype. For instance, if a parameter is supposed to be a torch.float64 tensor but I pass in a torch.float32 tensor, the dynamic library still runs. However, after a memory page swap or similar operation, the difference in element size between float32 and float64 can lead to out-of-bounds access that corrupts the GPU context. Because an IMA error is unrecoverable and fatal, in a multithreaded environment every subsequent GPU operation throws this exception and the program crashes. The stack trace of each crash varies depending on which thread touches the GPU first, which makes it difficult to pinpoint the source of the problem.
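A caller-side guard avoids this failure mode. A minimal sketch, assuming the expected dtype and shape are known at the call site (validate_input is a hypothetical helper, not part of the generated library):

```cpp
#include <ATen/ATen.h>
#include <c10/util/Exception.h>

// Hypothetical caller-side guard: the AOTI-generated library does not validate
// its inputs, so check dtype and shape before handing the tensor over.
void validate_input(const at::Tensor& t,
                    at::ScalarType expected_dtype,
                    at::IntArrayRef expected_shape) {
  TORCH_CHECK(t.scalar_type() == expected_dtype,
              "dtype mismatch: expected ", expected_dtype,
              ", got ", t.scalar_type());
  TORCH_CHECK(t.sizes() == expected_shape,
              "shape mismatch: expected ", expected_shape,
              ", got ", t.sizes());
}

// Example: if the exported model expects a float64 tensor of shape
// [batch, 128] (both values are placeholders):
//   validate_input(input, at::kDouble, {batch, 128});
```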
🐛 Describe the bug
I compiled my nn.Module on CUDA into a dynamic library using aot_compile and wrapped it into a TensorFlow operator for execution. The inference process is as follows: first, copy the at::Tensor from CPU to CUDA; then run inference with AOTIContainerRunner; finally, copy the inference result from CUDA back to the CPU.
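For reference, the per-request flow looks roughly like the sketch below, assuming the torch::inductor::AOTIModelContainerRunnerCuda API shipped with libtorch 2.3; "model.so" and the input shape are placeholders:

```cpp
#include <vector>
#include <torch/torch.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>

int main() {
  c10::InferenceMode mode;

  // Load the dynamic library produced by torch._export.aot_compile.
  torch::inductor::AOTIModelContainerRunnerCuda runner("model.so");

  // 1. Copy the input at::Tensor from CPU to CUDA.
  at::Tensor cpu_input = torch::randn({8, 10});
  std::vector<torch::Tensor> inputs = {cpu_input.to(torch::kCUDA)};

  // 2. Run inference through the AOTI runner.
  std::vector<torch::Tensor> outputs = runner.run(inputs);

  // 3. Copy the result from CUDA back to the CPU.
  at::Tensor cpu_output = outputs[0].to(torch::kCPU);
  return 0;
}
```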
This function is executed by a TensorFlow thread pool, so it runs in parallel. However, the AOTIContainerRunner path needs to run auto* model = get_available_model(); first, so its parallelism is capped by a fixed pool size (I've set model_num = 40).
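For context, get_available_model() is backed by a fixed-size pool of runners, roughly like the sketch below (the class, its blocking strategy, and the constructor arguments are illustrative assumptions, not the exact implementation):

```cpp
#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>

using Runner = torch::inductor::AOTIModelContainerRunnerCuda;

// Illustrative fixed-size pool: model_num runners are created up front, and
// acquire() blocks until one is free, capping effective parallelism at model_num.
class ModelPool {
 public:
  ModelPool(const std::string& so_path, size_t model_num) {
    for (size_t i = 0; i < model_num; ++i) {
      free_.push_back(std::make_unique<Runner>(so_path));
    }
  }

  std::unique_ptr<Runner> acquire() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !free_.empty(); });
    auto runner = std::move(free_.back());
    free_.pop_back();
    return runner;
  }

  void release(std::unique_ptr<Runner> runner) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      free_.push_back(std::move(runner));
    }
    cv_.notify_one();
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::vector<std::unique_ptr<Runner>> free_;
};
```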
When QPS is below 120, everything works fine. But once QPS exceeds 120, the error 'CUDA error: an illegal memory access was encountered' occurs immediately, and GPU utilization drops to zero. My GPU has 24 GB of memory, but less than 2 GB is in use at that point. Switching to another server with the same GPU model didn't help.

Check the stack information at the bottom; the key error is an illegal memory access in frame #18, caused by the .to(cuda) call.
frame #18: at::Tensor::to(c10::TensorOptions, bool, bool, std::optional<c10::MemoryFormat>) const + 0x1bb (0x55d57b1cf533 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/bin/sap_server_d.bin)
The corresponding C++ code is this line.
I put the full stack information at the very end. Initially, I suspected that the to('cuda') operation was not thread-safe, so I added a mutex around that line, but it didn't help. Later, I wondered whether too many tensors were being copied to the GPU, causing AOTIContainerRunner to lag behind. To address this, I modified the Torch code to serialize the data copy to the GPU and model->run, as follows:

Still to no avail.
I suspected that AOTIContainerRunner might be modifying the memory region the tensor actually points to, so I called .clone() to create a deep copy before calling to('cuda'), but it still didn't work.
Finally, I added a lock so that only one thread at a time could complete the three steps (copy to CUDA, run inference, copy back to CPU) before releasing the lock and letting the next thread proceed:
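Roughly like the sketch below (the global mutex and function names are illustrative, not the exact code):

```cpp
#include <mutex>
#include <vector>
#include <torch/torch.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>

// Illustrative global mutex: only one thread at a time may copy to CUDA,
// run the AOTI model, and copy the result back.
static std::mutex g_infer_mutex;

at::Tensor infer_serialized(
    torch::inductor::AOTIModelContainerRunnerCuda& runner,
    const at::Tensor& cpu_input) {
  std::lock_guard<std::mutex> lock(g_infer_mutex);

  // Step 1: copy the input from CPU to CUDA.
  std::vector<torch::Tensor> inputs = {cpu_input.to(torch::kCUDA)};

  // Step 2: run inference through the AOTI runner.
  std::vector<torch::Tensor> outputs = runner.run(inputs);

  // Step 3: copy the result back to the CPU before releasing the lock.
  return outputs[0].to(torch::kCPU);
}
```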
This solved the issue of illegal memory access.
In this case, the maximum GPU utilization reached 57%, but that is still too low. Since inference is effectively single-threaded, there is still a lot of room to improve QPS. I have confirmed that there are no shared resources between my threads. The TensorFlow build I'm using is the CPU version, so it doesn't manage any CUDA resources; that management should be properly encapsulated inside the to('cuda') path of libtorch_cpu.so. I suspect that the maximum GPU memory usage might be limited internally by AOTIContainerRunner.
Versions
$python collect_env.py
Collecting environment information...
PyTorch version: 2.3.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Alibaba Group Enterprise Linux Server 7.2 (Paladin) (x86_64)
GCC version: (GCC) 10.2.1 20200825 (Alibaba 10.2.1-3 2.17)
Clang version: 13.0.1 (Alibaba 13.0.1-2.fix20240307155812.alios7 77024452014c2872fa1c55b8456488fcbf309dcc)
CMake version: Could not collect
Libc version: glibc-2.32
Python version: 3.8.12 | packaged by conda-forge | (default, Jan 30 2022, 23:53:36) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-4.19.91-009.ali4000.alios7.x86_64-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 12.3.52
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn.so.8.9.5
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.9.5
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.9.5
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.9.5
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.9.5
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.9.5
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.9.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 106
Model name: Intel(R) Xeon(R) Platinum 8369B CPU @ 2.90GHz
Stepping: 6
CPU MHz: 2899.991
CPU max MHz: 3500.0000
CPU min MHz: 800.0000
BogoMIPS: 5800.00
Virtualization: VT-x
L1d cache: 48K
L1i cache: 32K
L2 cache: 1280K
L3 cache: 49152K
NUMA node0 CPU(s): 0-31,64-95
NUMA node1 CPU(s): 32-63,96-127
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 invpcid_single ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid fsrm md_clear pconfig flush_l1d arch_capabilities
Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] torch==2.3.1
[pip3] torch_tensorrt==2.3.0
[pip3] torchvision==0.18.1
[pip3] triton==2.3.1
[conda] numpy 1.24.4 pypi_0 pypi
[conda] torch 2.3.1 pypi_0 pypi
[conda] torch-tensorrt 2.3.0 pypi_0 pypi
[conda] torchvision 0.18.1 pypi_0 pypi
[conda] triton 2.3.1 pypi_0 pypi
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @desertfire @chenyang78