Failing to free a tensor allocated while a torch.cuda.Mempool is active results in that tensor being freed with cudaFree() rather than the custom free function. · Issue #146431 · pytorch/pytorch · GitHub
Closed
@galv

Description

🐛 Describe the bug

torch.cuda.MemPool was recently merged into the main branch. It should allow us to mix and match allocators in PyTorch, but I have found that the current implementation is broken.

I was inspired by recent work to "suspend" temporary GPU buffers used by CUDA graphs, which both SGLang and vLLM seem to have explored very recently:

sgl-project/sglang#2630
vllm-project/vllm#11743

CUDA graphs "bake in" the virtual addresses of their buffers. By using the CUDA VMM APIs (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html), we can release the physical backing memory without releasing the virtual addresses. This fixes the problem that CUDA graphs hold memory that other CUDA graphs or non-graph code cannot use. It is an alternative to sharing a mempool id, but in my opinion a more flexible one, because it also lets us suspend non-cudagraph allocations (i.e., allocations not made during CUDA graph capture; the KV cache is the most obvious beneficiary).

Additionally, we could explore suspending the cudaGraphExec_t itself, which lives in GPU memory and can be substantial: each node is about 10 KiB, and graphs with 10,000 or more nodes are not unusual for large workloads, which works out to roughly 100 MiB. But that is not the immediate use case.
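To make the VMM idea concrete, here is a toy Python model, not CUDA code: a buffer keeps one virtual address for its whole lifetime while its physical backing is released and re-acquired. `PhysicalPool` and `SuspendableBuffer` are hypothetical names; the comments only indicate which CUDA driver calls (cuMemAddressReserve, cuMemCreate, cuMemMap, cuMemUnmap, cuMemRelease) would roughly play each role in a real implementation. The model assumes the buffer holds temporaries, since its contents are not preserved across a suspend.

```python
from dataclasses import dataclass


@dataclass
class PhysicalPool:
    """Stand-in for device memory; only tracks how many bytes are backed."""

    capacity: int
    in_use: int = 0

    def create(self, size: int) -> int:
        # Roughly the role of cuMemCreate: obtain `size` bytes of physical memory.
        if self.in_use + size > self.capacity:
            raise MemoryError("out of physical memory")
        self.in_use += size
        return size

    def release(self, size: int) -> None:
        # Roughly the role of cuMemRelease once the memory is unmapped.
        self.in_use -= size


@dataclass
class SuspendableBuffer:
    """A buffer whose virtual address is fixed for its whole lifetime.

    `va` plays the role of a range reserved once with cuMemAddressReserve;
    only the physical backing comes and goes.
    """

    phys: PhysicalPool
    va: int
    size: int
    mapped: bool = False

    def resume(self) -> None:
        # Back the *same* VA with fresh physical memory (cuMemCreate + cuMemMap).
        if not self.mapped:
            self.phys.create(self.size)
            self.mapped = True

    def suspend(self) -> None:
        # Drop the physical memory but keep the VA reservation
        # (cuMemUnmap + cuMemRelease); contents are lost.
        if self.mapped:
            self.phys.release(self.size)
            self.mapped = False


MiB = 1 << 20
gpu = PhysicalPool(capacity=100 * MiB)
graph_buf = SuspendableBuffer(gpu, va=0x7F0000000000, size=80 * MiB)
graph_buf.resume()               # a graph is captured against graph_buf.va
graph_buf.suspend()              # 80 MiB of physical memory is free again
scratch = gpu.create(80 * MiB)   # e.g. a training step uses that memory
gpu.release(scratch)
graph_buf.resume()               # same VA, so baked-in pointers are valid again
```

The point of the design is that `graph_buf.va` never changes, so nothing captured against it has to be re-recorded; only the physical bytes counted in `gpu.in_use` come and go.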

The ability to suspend memory when switching between inference and training in the same process is important for new RL workloads; otherwise, memory constraints may keep you from reaching full performance in either training or inference. It is also important for LLM inference, where it is common to create a separate CUDA graph for each input-shape bucket, although that use case is partially covered by sharing memory pool ids between CUDA graphs.

I am not 100% certain, but I believe @youkaichao encountered an issue similar to mine here: #145168

My reproducer is on commit 1c16cf7.

I applied this diff first to get better visibility:

diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp
index 9f335c5fc1e..c02b7705cf5 100644
--- a/c10/cuda/CUDACachingAllocator.cpp
+++ b/c10/cuda/CUDACachingAllocator.cpp
@@ -2941,8 +2941,10 @@ class DeviceCachingAllocator {
 
       // If there is an active mempool with a given allocator,
       // we use the given allocator's delete function.
+      std::cout << "GALVEZ:custom allocator free" << std::endl;
       active_pool->allocator()->raw_delete((void*)block->ptr);
     } else {
+      std::cout << "GALVEZ:cudaFree()" << std::endl;
       C10_CUDA_CHECK(cudaFree((void*)block->ptr));
     }
     total_allocated_memory -= block->size;
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index 5e4a637f851..4a8aa614e0b 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -5390,6 +5390,7 @@ static void* _ncclMemAlloc(size_t size, int device, void* stream) {
       false, "NCCL mem allocator is not supported in this NCCL version");
 #else
   LOG(INFO) << "NCCL mem allocator: allocating " << size << " bytes";
+  std::cout << "GALVEZ:_ncclMemAlloc()" << std::endl;
   at::cuda::OptionalCUDAGuard gpuGuard(device);
   void* ptr = nullptr;
   TORCH_CHECK(ncclMemAlloc(&ptr, size) == ncclSuccess, "ncclMemAlloc failed");
@@ -5404,6 +5405,7 @@ static void _ncclMemFree(void* ptr, size_t size, int device, void* stream) {
       false, "NCCL mem allocator is not supported in this NCCL version");
 #else
   LOG(INFO) << "NCCL mem allocator: freeing " << size << " bytes";
+  std::cout << "GALVEZ:_ncclMemFree()" << std::endl;
   at::cuda::OptionalCUDAGuard gpuGuard(device);
   TORCH_CHECK(ncclMemFree(ptr) == ncclSuccess, "ncclMemFree failed");
 #endif // NCCL_HAS_MEM_ALLOC

I built like this:

USE_GLOG=ON USE_MEM_EFF_ATTENTION=0 USE_FLASH_ATTENTION=0 USE_DISTRIBUTED=1 USE_MKLDNN=0 BUILD_TEST=1 USE_FBGEMM=0 USE_NNPACK=0 USE_QNNPACK=0 USE_XNNPACK=0 DEBUG=1 TORCH_CUDA_ARCH_LIST="8.9" CMAKE_C_COMPILER_LAUNCHER=ccache CMAKE_CXX_COMPILER_LAUNCHER=ccache CMAKE_CUDA_COMPILER_LAUNCHER=ccache python setup.py develop

When you run:

import torch
import torch.distributed as c10d
import torch._logging
import logging
import os

# torch._logging.set_logs(all=logging.INFO)

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "7999"

c10d.init_process_group(
    backend="nccl", rank=0, world_size=1,
)
device = torch.device("cuda:0")
torch.cuda.set_device(0)
pg = c10d.distributed_c10d._get_default_group()
backend = pg._get_backend(device)

def f(backend):
    pool = torch.cuda.MemPool(backend.mem_allocator)

    # Allocate memory with ncclMemAlloc.
    with torch.cuda.use_mem_pool(pool):
        x = torch.arange(1024 * 1024 * 2, device="cuda")
    # Note: pool will be destroyed upon function return, but x, which
    # was allocated via the pool, is still alive.
    return x

x = f(backend)

# x's memory is only returned to the caching allocator's cache here.
del x

# Releases the cached memory. No MemPool is active at this point, so this
# calls cudaFree() on x's memory instead of _ncclMemFree().
torch.cuda.empty_cache()

c10d.destroy_process_group()

What you get is the following stdout:

(pytorch-5) dgalvez@dgalvez-dvt-01:~/code/asr/pytorch-5$ python repros/basic_mempool.py 
GALVEZ:_ncclMemAlloc()
GALVEZ:cudaFree()

This means that _ncclMemAlloc() is not matched by a call to _ncclMemFree(): memory obtained from ncclMemAlloc() is released with cudaFree() instead. This is a serious error. No error is reported, but this is clearly not okay.

I looked into it, and the problem is that the current logic assumes that the memory pool used for the allocation is still active when the memory is finally freed:

auto* pool = block->pool;
auto active_pool = MemPoolContext::getActiveMemPool();
if (active_pool && active_pool->allocator() && pool->owner_PrivatePool) {
  // Ensure that active_pool and pool are the same
  auto pp = get_private_pool(active_pool->id());
  TORCH_INTERNAL_ASSERT(pp == pool->owner_PrivatePool);
  // If there is an active mempool with a given allocator,
  // we use the given allocator's delete function.
  active_pool->allocator()->raw_delete((void*)block->ptr);
} else {
  C10_CUDA_CHECK(cudaFree((void*)block->ptr));
}

This is not good for multiple reasons:

  1. If the mempool used for the original allocation is not active, an error should be reported. We should not fall back to cudaFree(), which is not necessarily correct for a custom memory allocator. We are just getting lucky that no error is reported right now.
  2. cudaFree() is not called when a tensor is destroyed. It is called when the cached backing memory is released, e.g. when torch.cuda.empty_cache() is called, or when a cudaMalloc fails (for example due to fragmentation) and the allocator frees cached blocks before retrying. The user does not control this in general. Therefore, we cannot reliably write code that always releases backing memory while the appropriate memory pool is active, short of calling torch.cuda.empty_cache() at the end of each memory pool's "with" region, which synchronizes the CUDA stream. I think the current design is therefore fundamentally flawed.
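The behavior described in the two points above can be modeled without a GPU. The sketch below is a simplified, hypothetical Python model of the caching allocator's current free path (`ToyCachingAllocator`, its methods, and the string log entries are illustrative, not PyTorch APIs). It encodes only two rules taken from the quoted code and the reproducer: `free()` merely caches a block, and the later device-level free picks the custom free function only if a pool is active at that moment.

```python
from typing import Callable, Dict, List, Optional, Tuple

# A cached block: (device pointer, id of the MemPool it came from, or None).
Block = Tuple[int, Optional[str]]


class ToyCachingAllocator:
    """Simplified model of the current free path (not PyTorch code)."""

    def __init__(self) -> None:
        self.active_pool: Optional[str] = None  # stand-in for getActiveMemPool()
        self.allocators: Dict[str, Callable[[int], str]] = {}  # pool id -> raw_delete
        self.cached: List[Block] = []
        self.log: List[str] = []

    def malloc(self, ptr: int) -> Block:
        # Remember which pool (if any) was active at allocation time.
        return (ptr, self.active_pool)

    def free(self, block: Block) -> None:
        # Like `del x`: the block goes back to the cache; nothing is released.
        self.cached.append(block)

    def empty_cache(self) -> None:
        for ptr, owner in self.cached:
            active = self.active_pool
            if active is not None and owner is not None:
                # Counterpart of TORCH_INTERNAL_ASSERT(pp == pool->owner_PrivatePool).
                assert active == owner, "a different pool is active"
                self.log.append(self.allocators[active](ptr))
            else:
                # No pool active at release time: fall back to cudaFree().
                self.log.append(f"cudaFree({ptr:#x})")
        self.cached.clear()


alloc = ToyCachingAllocator()
alloc.allocators["nccl"] = lambda ptr: f"ncclMemFree({ptr:#x})"

alloc.active_pool = "nccl"   # enter `with torch.cuda.use_mem_pool(pool):`
x = alloc.malloc(0x1000)     # x = torch.arange(...)
alloc.active_pool = None     # leave the `with` block
alloc.free(x)                # del x
alloc.empty_cache()          # torch.cuda.empty_cache()
print(alloc.log)             # ['cudaFree(0x1000)']
```

The log shows the same mismatch as the reproducer's stdout: memory allocated through the NCCL pool ends up in `cudaFree`, because the pool was no longer active when the cache was flushed.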

What I propose instead is the following:

  • We require that a CUDAPluggableAllocator stay alive for as long as any of its allocations. This seems easy to enforce, given that the current design pattern is to use static instances of them.
  • Each PrivatePool will have a new member pointing to the allocator used by the memory pool that created that private pool:
    struct PrivatePool {
      PrivatePool()
          : large_blocks(/*small=*/false, this),
            small_blocks(/*small=*/true, this) {}
      PrivatePool(const PrivatePool&) = delete;
      PrivatePool(PrivatePool&&) = delete;
      PrivatePool& operator=(const PrivatePool&) = delete;
      PrivatePool& operator=(PrivatePool&&) = delete;
      ~PrivatePool() = default;
      // Number of live graphs using this pool
      int use_count{1};
      // Number of unfreed cudaMallocs made for this pool. When use_count and
      // cudaMalloc_count drop to zero, we can delete this PrivatePool from
      // graph_pools.
      int cudaMalloc_count{0};
      // Instead of maintaining private BlockPools here, I could stuff all blocks
      // (private or no) into the top-level large_blocks and small_blocks, and
      // distinguish private blocks by adding a "pool id" check above the stream
      // check in BlockComparator. BlockComparator is performance-critical though,
      // I'd rather not add more logic to it.
      BlockPool large_blocks;
      BlockPool small_blocks;
      // Proposed new member (name illustrative): the allocator of the MemPool
      // that created this pool; nullptr means plain cudaMalloc/cudaFree.
      CUDAAllocator* allocator{nullptr};
    };
    This is an acceptable pattern because, IIUC, there is a one-to-one relationship between memory pool IDs and private pools:
    // Private pools for CUDA graphs
    ska::flat_hash_map<MempoolId_t, std::unique_ptr<PrivatePool>, MempoolIdHash>
        graph_pools;
  • The above two points allow us to always call the PluggableAllocator's free function regardless of whether the MemPool used for allocation is active at the time of freeing.
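Under the proposal, the free function is looked up from the pool that owns the block rather than from the active pool. Here is a minimal sketch in the same toy style, assuming, as proposed, that each mempool id maps to exactly one private pool and that the allocator outlives all of its allocations; `ToyPrivatePool` and `ToyProposedAllocator` are hypothetical names, not PyTorch code.

```python
from typing import Callable, Dict, List, Optional, Tuple

Block = Tuple[int, Optional[int]]  # (device pointer, owning mempool id or None)


class ToyPrivatePool:
    """A private pool that remembers the allocator of the MemPool that created it."""

    def __init__(self, raw_delete: Optional[Callable[[int], str]]) -> None:
        # None models a pool without a custom allocator (plain cudaMalloc/cudaFree).
        self.raw_delete = raw_delete


class ToyProposedAllocator:
    """Simplified model of the proposed release path (not PyTorch code)."""

    def __init__(self) -> None:
        # One private pool per mempool id, like graph_pools.
        self.graph_pools: Dict[int, ToyPrivatePool] = {}
        self.active_id: Optional[int] = None
        self.cached: List[Block] = []
        self.log: List[str] = []

    def create_pool(self, pool_id: int, raw_delete: Optional[Callable[[int], str]]) -> None:
        self.graph_pools[pool_id] = ToyPrivatePool(raw_delete)

    def malloc(self, ptr: int) -> Block:
        return (ptr, self.active_id)

    def free(self, block: Block) -> None:
        self.cached.append(block)

    def empty_cache(self) -> None:
        for ptr, owner_id in self.cached:
            # Dispatch on the pool that owns the block, not on the active pool.
            # Indexing (not .get) raises KeyError if the owner is gone, instead
            # of silently falling back to cudaFree().
            pool = self.graph_pools[owner_id] if owner_id is not None else None
            if pool is not None and pool.raw_delete is not None:
                self.log.append(pool.raw_delete(ptr))
            else:
                self.log.append(f"cudaFree({ptr:#x})")
        self.cached.clear()


alloc = ToyProposedAllocator()
alloc.create_pool(1, lambda ptr: f"ncclMemFree({ptr:#x})")
alloc.active_id = 1
x = alloc.malloc(0x1000)
alloc.active_id = None   # the pool is no longer active
alloc.free(x)
alloc.empty_cache()
print(alloc.log)         # ['ncclMemFree(0x1000)']
```

Raising on a missing owner, rather than falling back to `cudaFree`, is in the spirit of point 1 above: a block whose allocator has gone away should be reported, not freed with the wrong function.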

This seems much more correct and much less error-prone, but I might be missing something.

What do you think @syed-ahmed @ezyang ?

Versions

PyTorch version: 2.7.0a0+git1c16cf7
Is debug build: True
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.4
Libc version: glibc-2.35

Python version: 3.9.21 (main, Dec 11 2024, 16:24:11) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-46-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.6.77
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA L40S
Nvidia driver version: 560.35.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: False

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 45 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 25
On-line CPU(s) list: 0-24
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9454 48-Core Processor
CPU family: 25
Model: 17
Thread(s) per core: 1
Core(s) per socket: 1
Socket(s): 25
Stepping: 1
BogoMIPS: 5491.74
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext invpcid_single ibpb vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor fsrm flush_l1d
Hypervisor vendor: VMware
Virtualization type: full
L1d cache: 800 KiB (25 instances)
L1i cache: 800 KiB (25 instances)
L2 cache: 25 MiB (25 instances)
L3 cache: 800 MiB (25 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-24
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP disabled, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] mypy==1.13.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.22.4
[pip3] optree==0.13.0
[pip3] torch==2.7.0a0+git1c16cf7
[conda] numpy 1.22.4 pypi_0 pypi
[conda] optree 0.13.0 pypi_0 pypi
[conda] torch 2.7.0a0+git1c16cf7 dev_0

cc @ptrblck @msaroufim @eqy

Labels: module: cuda, triaged