[BUG] DataLoader low GPU utilization and extremely slow compared to manual batching · Issue #154318 · pytorch/pytorch · GitHub

[BUG] DataLoader low GPU utilization and extremely slow compared to manual batching #154318


Open
jobs-git opened this issue May 25, 2025 · 5 comments
Labels
module: data (torch.utils.data) · module: dataloader (Related to torch.utils.data.DataLoader and Sampler) · module: performance (Issues related to performance, either of kernel code or framework glue) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

jobs-git commented May 25, 2025

🐛 Describe the bug

DataLoader retrieves data about 7-22x slower (and up to 50x slower with bfloat16) compared to direct tensor indexing, even though direct access also reads the data from CPU memory.

Here is a reproducible sample:

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, in_features: int, out_features: int, layer_size: int, hidden_size: int):
        super().__init__()

        layers = [
            nn.Linear(in_features, hidden_size),
            nn.ReLU(), 
        ]

        for _ in range(layer_size - 1):
            layers.extend([
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
            ])

        layers.append(nn.Linear(hidden_size, out_features)) 

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

import torch
import torch.nn as nn
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_samples = 7_000_000
in_features = 100
out_features = 10
hidden_size = 512
layer_size = 2
batch_size = 1_048_576

X = torch.randn(num_samples, in_features) 

model = MLP(in_features, out_features, layer_size, hidden_size).to(device)

start_time = time.time()

with torch.no_grad():
    for i in range(0, num_samples, batch_size):
        batch_X = X[i : i + batch_size].to(device)
        output = model(batch_X) 

torch.cuda.synchronize()
end_time = time.time()

print(f"Direct Access Forward Pass Time: {end_time - start_time:.3f} seconds")
torch.cuda.empty_cache()

from torch.utils.data import DataLoader

dataloader = DataLoader(X, batch_size=batch_size, shuffle=True)

model = MLP(in_features, out_features, layer_size, hidden_size).to(device)

start_time = time.time()

with torch.no_grad():
    for batch_X in dataloader:
        output = model(batch_X.to(device)) 

torch.cuda.synchronize()
end_time = time.time()

print(f"DataLoader Forward Pass Time: {end_time - start_time:.3f} seconds")
torch.cuda.empty_cache()

results are as follows:

Direct Access Forward Pass Time: 0.640 seconds
DataLoader Forward Pass Time:   15.653 seconds

Other ref: https://stackoverflow.com/questions/76838721/iterating-over-pytorch-dataloader-slower-than-direct-dataset-access
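For a like-for-like comparison with shuffle=True, the manual loop can shuffle as well; a minimal sketch (reusing X, model, batch_size, num_samples, and device from the script above):

perm = torch.randperm(num_samples)              # random sample order, analogous to shuffle=True

with torch.no_grad():
    for i in range(0, num_samples, batch_size):
        idx = perm[i : i + batch_size]          # indices for this batch
        batch_X = X[idx].to(device)             # one fancy-index slice, then copy to GPU
        output = model(batch_X)

torch.cuda.synchronize()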

With shuffle=False, num_workers, prefetch_factor, and pin_memory=True set on the DataLoader, and batch_X.to(device, non_blocking=True) for the transfer:

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, in_features: int, out_features: int, layer_size: int, hidden_size: int):
        super().__init__()

        layers = [
            nn.Linear(in_features, hidden_size),
            nn.ReLU(), 
        ]

        for _ in range(layer_size - 1):
            layers.extend([
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
            ])

        layers.append(nn.Linear(hidden_size, out_features)) 

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

import torch
import torch.nn as nn
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_samples = 7_000_000
in_features = 100
out_features = 10
hidden_size = 512
layer_size = 2
batch_size = 512 * 2  # reduced from 1_048_576 (see note below)

X = torch.randn(num_samples, in_features) 

model = MLP(in_features, out_features, layer_size, hidden_size).to(device)

start_time = time.time()

with torch.no_grad():
    for i in range(0, num_samples, batch_size):
        batch_X = X[i : i + batch_size].to(device, non_blocking=True)
        output = model(batch_X) 

torch.cuda.synchronize()
end_time = time.time()

print(f"Direct Access Forward Pass Time: {end_time - start_time:.3f} seconds")
torch.cuda.empty_cache()

from torch.utils.data import DataLoader

dataloader = DataLoader(X, batch_size=batch_size, shuffle=False, num_workers=12, prefetch_factor=12, pin_memory=True)

model = MLP(in_features, out_features, layer_size, hidden_size).to(device)

start_time = time.time()

with torch.no_grad():
    for batch_X in dataloader:
        output = model(batch_X.to(device, non_blocking=True)) 

torch.cuda.synchronize()
end_time = time.time()

print(f"DataLoader Forward Pass Time: {end_time - start_time:.3f} seconds")
torch.cuda.empty_cache()

results are as follows:

Direct Access Forward Pass Time: 0.608 seconds
DataLoader Forward Pass Time: 4.772 seconds

I had to reduce the batch size since the DataLoader fails to complete at the original size due to an unexpected bus error.
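A side note on the manual loop above: non_blocking=True only overlaps the host-to-device copy when the source tensor is in pinned (page-locked) memory, so the manual variant could pin X once up front. A minimal sketch, reusing names from the script above (pinning a tensor this large needs a few GB of page-locked RAM):

X_pinned = X.pin_memory()  # one-time copy into page-locked host memory

start_time = time.time()

with torch.no_grad():
    for i in range(0, num_samples, batch_size):
        # the sliced view of X_pinned is still pinned, so non_blocking=True can
        # actually overlap the copy with compute
        batch_X = X_pinned[i : i + batch_size].to(device, non_blocking=True)
        output = model(batch_X)

torch.cuda.synchronize()
print(f"Pinned Direct Access Forward Pass Time: {time.time() - start_time:.3f} seconds")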

Versions

PyTorch 2.7

cc @msaroufim @jerryzh168 @andrewkho @divyanshk @ssnl @VitalyFedyunin @dzhulgakov

@jobs-git jobs-git changed the title [BUG] DataLoader extremely slow [BUG] DataLoader is extremely slow May 25, 2025
@jobs-git jobs-git changed the title [BUG] DataLoader is extremely slow [BUG] DataLoader is extremely slow compared to manual batching May 26, 2025
@jbschlosser jbschlosser added the module: performance, module: dataloader, triaged, and module: data labels May 27, 2025
@divyanshk (Contributor) commented:

@jobs-git Curious, why do we have such a large batch size?

Also, it looks like in the 'direct access' setup we do not shuffle, whereas we do in the DataLoader.

@jbschlosser (Contributor) commented:

Agreed it's not fair to compare shuffling vs. non-shuffling, but I can repro DataLoader slowness even with shuffle=False. AFAICT it's due to inefficient tensor indexing + collation done by DataLoader.
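To make that suspected overhead concrete, here is a small stand-alone sketch (illustrative sizes, not numbers from this thread) comparing roughly what the default fetch-and-collate path does for a map-style dataset - one __getitem__ per sample followed by torch.stack - against a single advanced-indexing slice of the tensor:

import time
import torch

X = torch.randn(1_000_000, 100)
idx = torch.randperm(X.size(0))[:65_536]

# Roughly the default path with automatic batching: one __getitem__ per sample,
# then torch.stack to assemble the batch.
t0 = time.perf_counter()
batch_stacked = torch.stack([X[i] for i in idx])
t1 = time.perf_counter()

# A single advanced-indexing slice builds the same batch in one call.
t2 = time.perf_counter()
batch_sliced = X[idx]
t3 = time.perf_counter()

assert torch.equal(batch_stacked, batch_sliced)
print(f"per-sample index + stack: {t1 - t0:.4f} s")
print(f"single tensor slice:      {t3 - t2:.4f} s")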

@jobs-git (Author) commented:

> @jobs-git Curious, why do we have such a large batch size?
>
> Also, it looks like in the 'direct access' setup we do not shuffle, whereas we do in the DataLoader.

So we can accelerate model training.

Setting shuffle to False in the DataLoader leads to the same outcome; it shaves off 1-2 seconds, but we still arrive at the same conclusion.

@divyanshk (Contributor) commented:

@jobs-git The batch size is quite large - but if that works for your pipeline, that's great!

I tried playing around with your script. Using shuffle=False, num_workers=8, prefetch_factor=8, pin_memory=True and batch_X.to(device, non_blocking=True), I got the difference down to ~3x. This is a good test for the dataloader compared to just reading off a list/tensor - but the dataloader can do more, like applying a custom collate function per batch, which comes in very handy when you want to apply custom transformations on your data before feeding it to the GPU.

Direct Access Forward Pass Time: 3.078 seconds
DataLoader Forward Pass Time: 8.575 seconds

The dataloader might be doing some extra work here that is unnecessary for this simple use case. I can try and see if there is a way to make the dataloader not do that.
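For illustration, one way to sidestep the per-sample path for a plain tensor is to disable automatic batching and let the sampler yield whole batches of indices, so the tensor is sliced once per batch. A sketch only, not a vetted recommendation; it assumes the dataset supports list indexing (as a plain tensor does) and reuses X, model, batch_size, and device from the scripts above:

from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

# The sampler yields whole batches of indices; batch_size=None turns off automatic
# batching, so the dataset is indexed with the full index list at once (X[idx_list])
# instead of once per sample followed by collation.
batch_sampler = BatchSampler(SequentialSampler(X), batch_size=batch_size, drop_last=False)
loader = DataLoader(X, sampler=batch_sampler, batch_size=None, pin_memory=True)

with torch.no_grad():
    for batch_X in loader:
        output = model(batch_X.to(device, non_blocking=True))

torch.cuda.synchronize()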

@jobs-git jobs-git changed the title [BUG] DataLoader is extremely slow compared to manual batching [BUG] DataLoader low GPU utilization and extremely slow compared to manual batching May 29, 2025
@jobs-git (Author) commented May 29, 2025:

> @jobs-git The batch size is quite large - but if that works for your pipeline, that's great!
>
> I tried playing around with your script. Using shuffle=False, num_workers=8, prefetch_factor=8, pin_memory=True and batch_X.to(device, non_blocking=True), I got the difference down to ~3x. This is a good test for the dataloader compared to just reading off a list/tensor - but the dataloader can do more, like applying a custom collate function per batch, which comes in very handy when you want to apply custom transformations on your data before feeding it to the GPU.
>
> Direct Access Forward Pass Time: 3.078 seconds
> DataLoader Forward Pass Time: 8.575 seconds
>
> The dataloader might be doing some extra work here that is unnecessary for this simple use case. I can try and see if there is a way to make the dataloader not do that.

I tried this; the improvement is not bad, but it is still significantly slower in my test - about 600% slower, which is huge for functionality that is not always needed. GPU utilization is also very low, at 10-20%.

With bfloat16, the gap widens even further, to 50x (5000%)!

Increasing num_workers has the disadvantage of consuming more RAM than manual batching actually needs, and it can sometimes cause an unexpected bus error.

I like the DataLoader since it simplifies batching, but the bloat, slowness, and low GPU utilization are real issues that I hope can be resolved soon.
