Memory leak in multi-thread inference #64412

Open
mrshenli opened this issue Sep 2, 2021 · 8 comments
Labels
high priority module: memory usage PyTorch is using more memory than it should, or it is leaking memory module: multithreading Related to issues that occur when running on multiple CPU threads module: performance Issues related to performance, either of kernel code or framework glue triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@mrshenli
Contributor
mrshenli commented Sep 2, 2021

Symptom

The memory usage of a process increases significantly when using more threads. It's likely that PyTorch or some 3rd-party dependency creates expensive thread-local state (much more than the ~8MB per-thread stack size). See #61920 for more discussion. I tried a build without MKLDNN; that doesn't help. This issue seems to have existed for a long time (https://discuss.pytorch.org/t/heap-size-increase-constantly-when-inference-with-new-thread/57621).
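One way to rule out the Python-side thread stacks as the source of the growth is to cap their size explicitly with threading.stack_size before spawning the workers. A minimal diagnostic sketch (the 2 MB cap and shapes are illustrative, not part of the original repro):

import os
import threading
import psutil

import torch
from torchvision import models

# Cap the Python thread stacks at 2 MB; if RSS still grows by far more than
# ~2 MB per round, the growth comes from thread-local state in native code
# (MKL/OpenMP/PyTorch), not from the thread stacks themselves.
threading.stack_size(2 * 1024 * 1024)

model = models.resnet50()
inputs = torch.randn(120, 3, 128, 128)
process = psutil.Process(os.getpid())

for i in range(5):
    t = threading.Thread(target=lambda: model(inputs))
    t.start()
    t.join()
    print(f"Round {i}: memory {process.memory_info().rss // 1000000} MB")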

Repro

The following code creates 20 threads (one per round) and runs inference through the same ResNet50 model instance with the same input tensor. The process memory consumption keeps increasing as more threads are used.

import os
import threading
import psutil

import torch

from torchvision import models

batch_size = 120
image_w = 128
image_h = 128

def model_forward(model, input):
    return model(input)

if __name__=="__main__":
    model = models.resnet50()
    inputs = torch.randn(batch_size, 3, image_w, image_h)
    process = psutil.Process(os.getpid())
    for i in range(20):
        thread = threading.Thread(target=model_forward, args=(model, inputs))
        thread.start()
        thread.join()
        print(f"Round {i}: memory {process.memory_info().rss // 1000000} MB")
Round 0: memory 1491 MB
Round 1: memory 3261 MB
Round 2: memory 4242 MB
Round 3: memory 5342 MB
Round 4: memory 5423 MB
Round 5: memory 6220 MB
Round 6: memory 6635 MB
Round 7: memory 6915 MB
Round 8: memory 8206 MB
Round 9: memory 7646 MB
Round 10: memory 8966 MB
Round 11: memory 9125 MB
Round 12: memory 9690 MB
Round 13: memory 8881 MB
Round 14: memory 10648 MB
Round 15: memory 11228 MB
Round 16: memory 11716 MB
Round 17: memory 11000 MB
Round 18: memory 12270 MB
Round 19: memory 10994 MB

The same ResNet50 inference code with a single thread does not hit this issue.

import os
import psutil

import torch

from torchvision import models

batch_size = 120
image_w = 128
image_h = 128

def model_forward(model, input):
    return model(input)

if __name__=="__main__":
    model = models.resnet50()
    inputs = torch.randn(batch_size, 3, image_w, image_h)
    process = psutil.Process(os.getpid())
    for i in range(20):
        model_forward(model, inputs)
        print(f"Round {i}: memory {process.memory_info().rss // 1000000} MB")
Round 0: memory 1423 MB
Round 1: memory 1416 MB
Round 2: memory 1424 MB
Round 3: memory 2089 MB
Round 4: memory 2112 MB
Round 5: memory 2725 MB
Round 6: memory 2746 MB
Round 7: memory 2075 MB
Round 8: memory 2080 MB
Round 9: memory 1403 MB
Round 10: memory 2089 MB
Round 11: memory 2715 MB
Round 12: memory 3282 MB
Round 13: memory 3267 MB
Round 14: memory 3138 MB
Round 15: memory 1404 MB
Round 16: memory 2026 MB
Round 17: memory 1405 MB
Round 18: memory 1403 MB
Round 19: memory 1408 MB

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @VitalyFedyunin @ngimel @heitorschueroff

@mrshenli mrshenli added high priority module: memory usage PyTorch is using more memory than it should, or it is leaking memory module: multithreading Related to issues that occur when running on multiple CPU threads labels Sep 2, 2021
@mrshenli
Contributor Author
mrshenli commented Sep 2, 2021

This was also reported in #26165 and #24237.

@mrshenli mrshenli added the module: performance Issues related to performance, either of kernel code or framework glue label Sep 2, 2021
@ezyang
Contributor
ezyang commented Sep 2, 2021

@mrshenli and I talked about this a little bit: (1) nailing this down to the specific operator that's leaking would be a big help, (2) it would be worth verifying it doesn't leak on CUDA, and (3) we strongly suspect it's MKL/MKLDNN related, because we don't really do heavy thread caching inside the library itself.
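For (1), one possible way to bisect is to run a single candidate op in a fresh thread per round and watch RSS; if, say, convolution alone reproduces the per-thread growth, that points at the MKLDNN/oneDNN path. A rough sketch (the op choice and shapes are illustrative):

import os
import threading
import psutil

import torch

process = psutil.Process(os.getpid())

def run_conv():
    # Single candidate op; swap in linear/batch_norm/etc. to bisect further.
    x = torch.randn(120, 3, 128, 128)
    w = torch.randn(64, 3, 7, 7)
    torch.nn.functional.conv2d(x, w, stride=2, padding=3)

for i in range(10):
    t = threading.Thread(target=run_conv)
    t.start()
    t.join()
    print(f"conv2d round {i}: memory {process.memory_info().rss // 1000000} MB")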

@cloudhan
Contributor
cloudhan commented Sep 6, 2021

@mrshenli can you verify what I reported in #64535?

What is your CPU model?
What happens if you set the MKL_NUM_THREADS or OMP_NUM_THREADS env vars?
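For reference, a minimal sketch of pinning those from inside the script; setting the env vars in the shell before launching is the more reliable option, since they are read when MKL/OpenMP first initialize:

import os

# Must be set before torch (and hence MKL/OpenMP) is imported.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import torch

# Also cap PyTorch's own intra-op thread pool explicitly.
torch.set_num_threads(1)
print(torch.get_num_threads())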

@mrshenli
Contributor Author
mrshenli commented Sep 7, 2021

Hey @cloudhan

I wasn't able to reproduce the leak. My CPU info is: model name : Intel(R) Xeon(R) Gold 6138 CPU @ 2.00GHz

import torch
import psutil
import os

def run():
    # input_size * hidden_size >= 2048 causes the leak
    # also observed extreme performance degradation
    input_size = 16
    hidden_size = 128
    gru = torch.nn.GRU(input_size, hidden_size)
    inputs = torch.randn(1, 1, input_size)

    counter = 0
    process = psutil.Process(os.getpid())
    while True:
        counter += 1
        _, out = gru(inputs)
        if (counter % 100) == 0:
            print(f"Round {counter}: memory {process.memory_info().rss // 1000000} MB")


if __name__ == "__main__":
    with torch.no_grad():
        run()

Here is what I observed. I also tried setting MKL_NUM_THREADS and OMP_NUM_THREADS to different values; there is no leak in my local env.

Round 100: memory 150 MB  
Round 200: memory 150 MB            
Round 300: memory 150 MB  
Round 400: memory 150 MB            
Round 500: memory 150 MB  
Round 600: memory 150 MB            
Round 700: memory 150 MB  
Round 800: memory 150 MB            
Round 900: memory 150 MB  

@cloudhan
Contributor
cloudhan commented Sep 7, 2021

Fair enough, 6138 supports AVX512, so no leak is expected.
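For anyone checking whether their own CPU advertises AVX-512 (which, per the comment above, avoids the leak reported in #64535), a quick Linux-only sketch:

# Linux-only: /proc/cpuinfo lists "avx512f" when the CPU supports the
# AVX-512 foundation instructions.
with open("/proc/cpuinfo") as f:
    flags = f.read()

print("AVX-512 supported:", "avx512f" in flags)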

@jbschlosser jbschlosser added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Sep 13, 2021
@WilliamTambellini
Contributor

Experiencing a similar memory leak with libtorch 1.7.

@wenzhe-nrv
Collaborator

@cloudhan I could reproduce the memory leak using the above code even with a Xeon 8280 (AVX512 supported).
What fixed the problem for me was installing jemalloc. After that, the memory usage became stable.

Though I still saw some leakage when MKL (linear op) or oneDNN/ideep (conv op) was used, so I'm still checking those.

@ezyang I wonder if there's any update on this issue from the PyTorch team. Thanks.
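For reference, one way to run the repro under jemalloc is to preload the library before the process starts. A sketch that re-execs the script with LD_PRELOAD set (the library path is an assumption and varies by distro; exporting LD_PRELOAD in the shell works just as well):

import os
import sys

# Assumed path; adjust to wherever libjemalloc is installed on your system.
JEMALLOC = "/usr/lib/x86_64-linux-gnu/libjemalloc.so.2"

# LD_PRELOAD has to be in place before the process starts, so re-exec this
# script once with the variable set; all allocations then go through jemalloc.
if os.path.exists(JEMALLOC) and os.environ.get("LD_PRELOAD") != JEMALLOC:
    env = dict(os.environ, LD_PRELOAD=JEMALLOC)
    os.execve(sys.executable, [sys.executable] + sys.argv, env)

# ... the multi-thread repro from the issue description would follow here ...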

@Tacx79
Tacx79 commented Oct 5, 2023

I have a similar issue in a custom dataloader with threads and image transforms. As long as I move the data out of the thread as a numpy array and apply the transforms in the main thread, it runs fine; when I move the transforms into the thread, the leak appears and memory usage slowly grows over the epoch, leading to OOM (VRAM or RAM, depending on whether I apply the transforms on GPU or CPU).

  1. Exporting data from the thread as a numpy array and applying transforms in the main thread - no leak (see the sketch after this list)
  2. Applying transforms inside the thread AND running data through the model - big leak, 7-8 GB over 20k 320x320 images; gc.get_objects() shows a growing list of tensors with the shapes of the images and labels
  3. Applying transforms inside the thread AND NOT running data through the model - small leak, a few hundred MB over 20k images
  4. Exporting data from the thread as a tensor and applying transforms in the main thread - small leak, similar to point 3, but memory usage doesn't skyrocket when running data through the model
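A minimal sketch of the no-leak pattern in case 1, assuming a worker that only produces numpy arrays, a main thread that does the tensor conversion and transforms, and a torchvision recent enough that transforms accept tensors (names, shapes, and the transform pipeline are illustrative, not the actual dataloader):

import queue
import threading

import numpy as np
import torch
from torchvision import transforms

# Illustrative transform pipeline; the real one is not shown in the issue.
transform = transforms.Compose([transforms.Resize((320, 320))])

q = queue.Queue(maxsize=4)

def worker():
    # The worker only hands back raw numpy data; no torch ops run in this thread.
    for _ in range(8):
        q.put(np.random.randint(0, 255, (3, 480, 480), dtype=np.uint8))
    q.put(None)  # sentinel

threading.Thread(target=worker, daemon=True).start()

while (item := q.get()) is not None:
    # Tensor creation and transforms happen in the main thread (case 1 above).
    batch = transform(torch.from_numpy(item).float() / 255)
    print(batch.shape)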
