In a multithreading scenario, the dynamic library obtained from aot_compile encounters a CUDA illegal memory access error during inference. · Issue #132475 · pytorch/pytorch · GitHub

In a multithreading scenario, the dynamic library obtained from aot_compile encounters a CUDA illegal memory access error during inference. #132475


Closed
sujuyu opened this issue Aug 2, 2024 · 4 comments
Assignees
Labels
high priority module: aotinductor aot inductor module: crash Problem manifests as a hard crash, as opposed to a RuntimeError needs reproduction Someone else needs to try reproducing the issue given the instructions. No action needed from user oncall: pt2 triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

sujuyu commented Aug 2, 2024

🐛 Describe the bug

I compiled my nn.Module on CUDA into a dynamic library using aot_compile and wrapped it into a TensorFlow operator for execution. The inference process is as follows: first, copy the at::Tensor inputs from CPU to CUDA, then execute inference with the AOTIModelContainerRunner, and finally copy the inference results from CUDA back to the CPU.

std::vector<at::Tensor> realInputs;
realInputs.reserve(inputs.size());
for (size_t idx = 0; idx < inputs.size(); ++idx) {
    at::Tensor realInput = inputs[idx];
    realInput = realInput.to(_modelParams.actualDevice);  // upload to CUDA
    realInputs.push_back(realInput);
}
// run AOTInductor inference on the CUDA inputs
std::vector<at::Tensor> realOutputs = _aotiModelContainerRunner->run(realInputs);
// copy the results back to CPU
for (size_t idx = 0; idx < realOutputs.size(); ++idx) {
    outputs.push_back(realOutputs[idx].to(torch::kCPU));
}

This function is executed by a TensorFlow thread pool, so it runs in parallel. However, the AOTIModelContainerRunner must first obtain a model instance (auto* model = get_available_model();), so its parallelism is capped by the number of model instances (I've set model_num = 40).

_aotiModelContainerRunner = std::make_shared<torch::inductor::AOTIModelContainerRunnerCuda>(
    modelFilePath,
    aotModelNum,  // 40
    "cuda",
    FileUtil::getParentDir(_modelParams.aotModelFilePath));

When QPS is below 120, everything works fine. But once QPS exceeds 120, the error 'CUDA error: an illegal memory access was encountered' occurs immediately, and GPU usage drops to zero. My GPU has 24GB of memory, but it's using less than 2GB at that time. Switching to another server with the same GPU didn't help.
Check the stack information at the bottom; the key error is the illegal memory access in frame #18, triggered by the .to(cuda) call.

frame #18: at::Tensor::to(c10::TensorOptions, bool, bool, std::optional<c10::MemoryFormat>) const + 0x1bb (0x55d57b1cf533 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/bin/sap_server_d.bin)

The corresponding C++ code is this line.

realInput = realInput.to(_modelParams.actualDevice);

I put the full stack information at the very end. Initially, I suspected that the .to('cuda') operation was not thread-safe, so I added a mutex around that line, but it didn't help. Later, I wondered whether too many Tensors were being copied to the GPU, causing the AOTIModelContainerRunner to lag behind; to address this, I modified the Torch code to serialize the data copy to the GPU and the model->run call (screenshot omitted).
Still to no avail.
I then suspected that the AOTIModelContainerRunner might be modifying the memory that the Tensor actually points to, so I called .clone() to create a deep copy before calling .to('cuda'), but it still didn't work.

Finally, I added a lock to ensure that only one thread could complete the following three steps before releasing the lock, allowing the next thread to proceed:

  1. Converting a CPU Tensor to a CUDA Tensor
  2. Performing AOTI inference
  3. Converting the CUDA Tensor back to a CPU Tensor
This solved the illegal memory access issue.
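
For reference, below is a minimal sketch of that serialization workaround, assuming a runner and target device like the ones above; the mutex member (_gpuMutex) and the helper name aotPredictSerialized are hypothetical and only for illustration.

// Minimal sketch (not the exact production code): serialize the H2D copy, the
// AOTI run, and the D2H copy under one lock. _gpuMutex and aotPredictSerialized
// are hypothetical names used only for illustration.
#include <mutex>
#include <vector>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>

static std::mutex _gpuMutex;

std::vector<at::Tensor> aotPredictSerialized(
    const std::vector<at::Tensor>& inputs,
    torch::inductor::AOTIModelContainerRunnerCuda& runner,
    const at::Device& device) {
  std::lock_guard<std::mutex> guard(_gpuMutex);

  // 1. Copy CPU tensors to CUDA.
  std::vector<at::Tensor> realInputs;
  realInputs.reserve(inputs.size());
  for (const auto& t : inputs) {
    realInputs.push_back(t.to(device));
  }

  // 2. Run AOTInductor inference.
  std::vector<at::Tensor> realOutputs = runner.run(realInputs);

  // 3. Copy results back to CPU before releasing the lock.
  std::vector<at::Tensor> outputs;
  outputs.reserve(realOutputs.size());
  for (const auto& t : realOutputs) {
    outputs.push_back(t.to(torch::kCPU));
  }
  return outputs;
}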

In this case, the maximum GPU utilization reached 57%, which is still too low. Since inference effectively runs on a single thread, there is still plenty of room to improve QPS. I have now confirmed that there are no shared resources between my threads. The TensorFlow I'm using is the CPU version, so it does not manage any CUDA resources; CUDA resource management should be properly encapsulated within the .to('cuda') path in libtorch_cpu.so. I suspect that the maximum GPU memory usage might be limited internally by the AOTIModelContainerRunner.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fbc7692b897 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/lib64/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fbc768dbb25 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/lib64/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7fbc7684d718 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/lib64/libc10_cuda.so)
frame #3: <unknown function> + 0x1915edc (0x7fbc4425dedc in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/lib64/libtorch_cuda.so)
frame #4: <unknown function> + 0x1cab4bb (0x7fbc2cc794bb in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/lib64/libtorch_cpu.so)
frame #5: at::native::copy_(at::Tensor&, at::Tensor const&, bool) + 0x62 (0x7fbc2cc7ae22 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/lib64/libtorch_cpu.so)
frame #6: at::_ops::copy_::call(at::Tensor&, at::Tensor const&, bool) + 0x15c (0x7fbc2da0568c in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/lib64/libtorch_cpu.so)
frame #7: at::native::_to_copy(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) + 0x1abc (0x7fbc2cf91afc in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/lib64/libtorch_cpu.so)
frame #8: <unknown function> + 0x2de6f7b (0x7fbc2ddb4f7b in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/lib64/libtorch_cpu.so)
frame #9: at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) + 0xf5 (0x7fbc2d4bee75 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/lib64/libtorch_cpu.so)
frame #10: <unknown function> + 0x2c11243 (0x7fbc2dbdf243 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/lib64/libtorch_cpu.so)
frame #11: at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) + 0xf5 (0x7fbc2d4bee75 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/lib64/libtorch_cpu.so)
frame #12: <unknown function> + 0x44f6b94 (0x7fbc2f4c4b94 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/lib64/libtorch_cpu.so)
frame #13: <unknown function> + 0x44f6ffe (0x7fbc2f4c4ffe in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/lib64/libtorch_cpu.so)
frame #14: at::_ops::_to_copy::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) + 0x1eb (0x7fbc2d54ef9b in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/lib64/libtorch_cpu.so)
frame #15: at::native::to(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, bool, std::optional<c10::MemoryFormat>) + 0x11d (0x7fbc2cf8e77d in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/lib64/libtorch_cpu.so)
frame #16: <unknown function> + 0x2fdfa11 (0x7fbc2dfada11 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/lib64/libtorch_cpu.so)
frame #17: at::_ops::to_dtype_layout::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, bool, std::optional<c10::MemoryFormat>) + 0x200 (0x7fbc2d6f5cb0 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/lib64/libtorch_cpu.so)
frame #18: at::Tensor::to(c10::TensorOptions, bool, bool, std::optional<c10::MemoryFormat>) const + 0x1bb (0x55d57b1cf533 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/bin/sap_server_d.bin)
frame #19: suez::turing::PyTorchModelPredictor::aotPredict(std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> >&) + 0x324 (0x55d57b1cb4b4 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/bin/sap_server_d.bin)
frame #20: suez::turing::PyTorchModelPredictor::aotPredict(std::vector<tensorflow::Tensor, std::allocator<tensorflow::Tensor> > const&, std::vector<tensorflow::Tensor, std::allocator<tensorflow::Tensor> >&) + 0x67 (0x55d57b1ce84d in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/bin/sap_server_d.bin)
frame #21: suez::turing::PyTorchModelPredictor::predict(std::vector<tensorflow::Tensor, std::allocator<tensorflow::Tensor> > const&, std::vector<tensorflow::Tensor, std::allocator<tensorflow::Tensor> >&) + 0x98 (0x55d57b1cf02a in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/bin/sap_server_d.bin)
frame #22: suez::turing::PyTorchModelPredictOp::Compute(tensorflow::OpKernelContext*) + 0xad9 (0x55d57b1c6fc7 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/bin/sap_server_d.bin)
frame #23: auto_scaler::AutoScalerLocalDevice::Compute(tensorflow::OpKernel*, tensorflow::OpKernelContext*) + 0x54 (0x55d5738edbd2 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/bin/sap_server_d.bin)
frame #24: <unknown function> + 0x1818436e (0x55d580e3a36e in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/bin/sap_server_d.bin)
frame #25: <unknown function> + 0x181710f0 (0x55d580e270f0 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-hpai_na175-item.tgc_guide_rank_model_test-hpai_na175-item_partition_0_39_95/binary/usr/local/bin/sap_server_d.bin)

Versions

$python collect_env.py
Collecting environment information...
PyTorch version: 2.3.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Alibaba Group Enterprise Linux Server 7.2 (Paladin) (x86_64)
GCC version: (GCC) 10.2.1 20200825 (Alibaba 10.2.1-3 2.17)
Clang version: 13.0.1 (Alibaba 13.0.1-2.fix20240307155812.alios7 77024452014c2872fa1c55b8456488fcbf309dcc)
CMake version: Could not collect
Libc version: glibc-2.32

Python version: 3.8.12 | packaged by conda-forge | (default, Jan 30 2022, 23:53:36) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-4.19.91-009.ali4000.alios7.x86_64-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 12.3.52
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn.so.8.9.5
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.9.5
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.9.5
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.9.5
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.9.5
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.9.5
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.9.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 106
Model name: Intel(R) Xeon(R) Platinum 8369B CPU @ 2.90GHz
Stepping: 6
CPU MHz: 2899.991
CPU max MHz: 3500.0000
CPU min MHz: 800.0000
BogoMIPS: 5800.00
Virtualization: VT-x
L1d cache: 48K
L1i cache: 32K
L2 cache: 1280K
L3 cache: 49152K
NUMA node0 CPU(s): 0-31,64-95
NUMA node1 CPU(s): 32-63,96-127
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 invpcid_single ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid fsrm md_clear pconfig flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] torch==2.3.1
[pip3] torch_tensorrt==2.3.0
[pip3] torchvision==0.18.1
[pip3] triton==2.3.1
[conda] numpy 1.24.4 pypi_0 pypi
[conda] torch 2.3.1 pypi_0 pypi
[conda] torch-tensorrt 2.3.0 pypi_0 pypi
[conda] torchvision 0.18.1 pypi_0 pypi
[conda] triton 2.3.1 pypi_0 pypi
(torch231_cuda121)

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @desertfire @chenyang78

@malfet malfet added oncall: pt2 module: aotinductor aot inductor module: crash Problem manifests as a hard crash, as opposed to a RuntimeError labels Aug 2, 2024
@williamwen42
Member

@desertfire

@williamwen42 williamwen42 added high priority triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Aug 3, 2024
@zou3519 zou3519 added triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module and removed triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Oct 1, 2024
@desertfire
Contributor

Can you share standalone repro code? That would help us debug.

One thing to note is that the AOTI runtime does its own thread pool management, which might interfere with your thread pool; see a similar discussion (for AOTI on CPU) in #134946.

@mlazos mlazos added the needs reproduction Someone else needs to try reproducing the issue given the instructions. No action needed from user label Oct 15, 2024
sujuyu commented Oct 16, 2024

Sorry, I was not able to reproduce it in a standalone multithreaded environment; it only appears in the TensorFlow 1.x environment. But I compiled a PyTorch 2.3.0 package with the DEBUG option and got a more complete stack trace.

Exception raised from c10_cuda_check_implementation at /home/admin/zy429782/torch_folder/pytorch-2.3.1/c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: <unknown function> + 0x9c5b6 (0x7f94c57c35b6 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libc10.so)
frame #1: <unknown function> + 0x9c164 (0x7f94c57c3164 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libc10.so)
frame #2: <unknown function> + 0x9be08 (0x7f94c57c2e08 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libc10.so)
frame #3: <unknown function> + 0x9d323 (0x7f94c57c4323 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libc10.so)
frame #4: c10::Error::Error(c10::SourceLocation, std::string) + 0x28 (0x7f94c57c1ae2 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libc10.so)
frame #5: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x7e (0x7f94c57bfa61 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libc10.so)
frame #6: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x179 (0x7f94c5696c8a in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libc10_cuda.so)
frame #7: c10::cuda::memcpy_and_sync(void*, void const*, long, cudaMemcpyKind, CUstream_st*) + 0xc1 (0x7f94b80451ad in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cuda.so)
frame #8: <unknown function> + 0x3c5e3a4 (0x7f94b81183a4 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cuda.so)
frame #9: <unknown function> + 0x145ee0e (0x7f949f418e0e in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #10: <unknown function> + 0x145d63b (0x7f949f41763b in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #11: at::native::copy_(at::Tensor&, at::Tensor const&, bool) + 0xc3 (0x7f949f41787c in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #12: <unknown function> + 0x3037a78 (0x7f94a0ff1a78 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #13: <unknown function> + 0x31236dc (0x7f94a10dd6dc in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #14: <unknown function> + 0x21845e8 (0x7f94a013e5e8 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #15: at::_ops::copy_::call(at::Tensor&, at::Tensor const&, bool) + 0x496 (0x7f94a088e768 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #16: at::Tensor::copy_(at::Tensor const&, bool) const + 0x2c (0x7f94b740b134 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cuda.so)
frame #17: at::native::_to_copy(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) + 0x979 (0x7f949f7bda64 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #18: <unknown function> + 0x3041799 (0x7f94a0ffb799 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #19: <unknown function> + 0x3159361 (0x7f94a1113361 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #20: <unknown function> + 0x21869e8 (0x7f94a01409e8 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #21: <unknown function> + 0x203a51f (0x7f949fff451f in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #22: at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) + 0x21b (0x7f949fee3d97 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #23: <unknown function> + 0x2d32155 (0x7f94a0cec155 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #24: <unknown function> + 0x2d50943 (0x7f94a0d0a943 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #25: <unknown function> + 0x21869e8 (0x7f94a01409e8 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #26: <unknown function> + 0x203a51f (0x7f949fff451f in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #27: at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) + 0x21b (0x7f949fee3d97 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #28: <unknown function> + 0x4ff7720 (0x7f94a2fb1720 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #29: <unknown function> + 0x4eda6b5 (0x7f94a2e946b5 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #30: <unknown function> + 0x4edaa9a (0x7f94a2e94a9a in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #31: <unknown function> + 0x4fbd913 (0x7f94a2f77913 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #32: <unknown function> + 0x21869e8 (0x7f94a01409e8 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #33: at::_ops::_to_copy::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>) + 0x79b (0x7f949fee3a05 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #34: <unknown function> + 0xeb022a (0x7f949ee6a22a in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #35: <unknown function> + 0x18040c4 (0x7f949f7be0c4 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #36: at::native::to(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, bool, std::optional<c10::MemoryFormat>) + 0xba (0x7f949f7be531 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #37: <unknown function> + 0x3340463 (0x7f94a12fa463 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #38: <unknown function> + 0x34262f4 (0x7f94a13e02f4 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #39: <unknown function> + 0x24c87bc (0x7f94a04827bc in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #40: at::_ops::to_dtype_layout::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, bool, std::optional<c10::MemoryFormat>) + 0x85a (0x7f94a0273132 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libtorch_cpu.so)
frame #41: at::Tensor::to(c10::TensorOptions, bool, bool, std::optional<c10::MemoryFormat>) const + 0x1bb (0x55bb19cdb9e5 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/bin/sap_server_d.bin)
frame #42: suez::turing::PyTorchModelPredictor::aotPredict(std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> >&) + 0x592 (0x55bb19cd815e in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/bin/sap_server_d.bin)
frame #43: suez::turing::PyTorchModelPredictor::aotPredict(std::vector<tensorflow::Tensor, std::allocator<tensorflow::Tensor> > const&, std::vector<tensorflow::Tensor, std::allocator<tensorflow::Tensor> >&) + 0x67 (0x55bb19cdacff in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/bin/sap_server_d.bin)
frame #44: suez::turing::PyTorchModelPredictor::predict(std::vector<tensorflow::Tensor, std::allocator<tensorflow::Tensor> > const&, std::vector<tensorflow::Tensor, std::allocator<tensorflow::Tensor> >&) + 0x98 (0x55bb19cdb4dc in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/bin/sap_server_d.bin)
frame #45: suez::turing::PyTorchModelPredictOp::Compute(tensorflow::OpKernelContext*) + 0xad9 (0x55bb19cd2fc7 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/bin/sap_server_d.bin)
frame #46: auto_scaler::AutoScalerLocalDevice::Compute(tensorflow::OpKernel*, tensorflow::OpKernelContext*) + 0x54 (0x55bb123f9bd2 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/bin/sap_server_d.bin)
frame #47: <unknown function> + 0x1818554e (0x55bb1f94654e in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/bin/sap_server_d.bin)
frame #48: <unknown function> + 0x181722d0 (0x55bb1f9332d0 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/bin/sap_server_d.bin)
frame #49: std::function<void ()>::operator()() const + 0xe (0x55bb11a698a6 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/bin/sap_server_d.bin)
frame #50: Eigen::ThreadPoolTempl<tensorflow::thread::EigenEnvironment>::WorkerLoop(int) + 0x951 (0x55bb27b1bc15 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/bin/sap_server_d.bin)
frame #51: std::_Function_handler<void (), Eigen::ThreadPoolTempl<tensorflow::thread::EigenEnvironment>::ThreadPoolTempl(int, bool, tensorflow::thread::EigenEnvironment)::{lambda()#1}>::_M_invoke(std::_Any_data const&) + 0xf (0x55bb27b1bcb5 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/bin/sap_server_d.bin)
frame #52: std::function<void ()>::operator()() const + 0xe (0x55bb11a698a6 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/bin/sap_server_d.bin)
frame #53: std::_Function_handler<void (), tensorflow::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) + 0x2b (0x55bb27b189b2 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/bin/sap_server_d.bin)
frame #54: std::thread::_State_impl<std::thread::_Invoker<std::tuple<std::function<void ()> > > >::_M_run() + 0x15 (0x55bb12405a9d in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/bin/sap_server_d.bin)
frame #55: <unknown function> + 0xd3b55 (0x7f945c8c3b55 in /home/admin/hippo/worker/slave/suezops_rtp_base_workspace_prod_na175_1_tgc_guide_rank_model_test-na175-item.tgc_guide_rank_model_test-na175-item_partition_0_39_85/binary/usr/local/lib64/libstdc++.so.6)
frame #56: <unknown function> + 0x93f9 (0x7f945cd6d3f9 in /usr/lib64/libpthread.so.0)
frame #57: clone + 0x43 (0x7f945c70c303 in /usr/lib64/libc.so.6)

Compared to the previous stack, the error now appears in memcpy_and_sync.

sujuyu commented Oct 25, 2024

The cause of the IMA (illegal memory access) error is that the dynamic library generated by torch._export.aot_compile accepts at::Tensor inputs in C++ but does not check their shape or dtype. For instance, if a parameter is supposed to be a torch.float64 tensor but a torch.float32 tensor is passed in, the dynamic library will still run. However, after a memory page swap or similar operation, the difference in element size between float32 and float64 can lead to out-of-bounds access, corrupting the GPU context. Because an IMA error is unrecoverable and fatal, in a multithreaded environment every subsequent GPU operation throws this exception, and the program crashes. The stack trace of each crash varies depending on which thread touches the GPU first.

The consequence is that it becomes difficult to pinpoint the source of the problem.
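
Given that, one practical mitigation is to validate inputs on the caller side before handing them to the compiled library. Below is a minimal sketch of such a guard, assuming the expected dtypes and ranks of the model inputs are known from the export; checkAotInputs, expectedDtypes, and expectedDims are hypothetical names used only for illustration.

// Minimal sketch of a caller-side guard: verify dtype and rank of each input
// before passing the tensors to the aot_compile'd library. The expected values
// (expectedDtypes, expectedDims) are assumed to be recorded at export time.
#include <vector>
#include <ATen/ATen.h>
#include <c10/util/Exception.h>

void checkAotInputs(
    const std::vector<at::Tensor>& inputs,
    const std::vector<at::ScalarType>& expectedDtypes,
    const std::vector<int64_t>& expectedDims) {
  TORCH_CHECK(inputs.size() == expectedDtypes.size(),
              "expected ", expectedDtypes.size(), " inputs, got ", inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    TORCH_CHECK(inputs[i].scalar_type() == expectedDtypes[i],
                "input ", i, ": dtype mismatch, expected ", expectedDtypes[i],
                " but got ", inputs[i].scalar_type());
    TORCH_CHECK(inputs[i].dim() == expectedDims[i],
                "input ", i, ": rank mismatch, expected ", expectedDims[i],
                " but got ", inputs[i].dim());
  }
}

Such a check could be called right before the copy-to-CUDA loop shown earlier, so that a mismatched tensor fails fast with a readable error instead of corrupting the GPU context.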

@sujuyu sujuyu closed this as completed Oct 25, 2024