[Feature Request] Memory optimization for backward propagation in GPU #150698
Open
@jobs-git

Description

🚀 The feature, motivation and pitch

Backprop uses a lot of VRAM, often several times the combined size of the model parameters and input data, resulting in lower GPU utilization. Compute and power savings may be realized if backprop VRAM usage can be optimized.

Alternatives

Possible solutions (rough sketches of both follow this list):

  • Backprop calculation by parts. Instead of computing the backward pass as one big chain of matrix products, the chain could be divided into chunks, with each chunk's intermediate buffers freed as soon as they have been consumed.
  • Async/unified memory for intermediate data. When the backward pass exceeds available VRAM, the process currently dies with an OOM error; instead, the excess intermediate data could be spilled to CPU RAM via a unified-memory or asynchronous-offload framework.
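For the first idea, it may be worth noting that PyTorch's activation checkpointing (torch.utils.checkpoint) already does something close: it drops intermediate activations during forward and recomputes them segment by segment during backward, so only one segment's activations are live at a time. A minimal sketch reusing the layer sizes from the repro below (the segment boundaries are an arbitrary choice):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(128, 1024*256)
        self.fc2 = nn.Linear(1024*256, 1024)
        self.fc3 = nn.Linear(1024, 16)

    def forward(self, x):
        # Each checkpointed segment frees its activations after forward and
        # recomputes them during backward, trading compute for memory.
        x = checkpoint(lambda t: torch.relu(self.fc1(t)), x, use_reentrant=False)
        x = checkpoint(lambda t: torch.relu(self.fc2(t)), x, use_reentrant=False)
        return self.fc3(x)

This only reduces activation memory, though; it does not touch gradient or optimizer-state memory.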
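For the second idea, torch.autograd.graph.save_on_cpu is a partial building block that already exists: it moves tensors saved for backward to (optionally pinned) CPU memory and copies them back on demand. A minimal sketch:

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(128, 1024*256), nn.ReLU(),
    nn.Linear(1024*256, 1024), nn.ReLU(),
    nn.Linear(1024, 16),
).cuda()

x = torch.randn(128, 128, device="cuda")

# Tensors saved for backward live in pinned CPU memory during forward and
# are copied back to the GPU only when backward needs them.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    loss = model(x).square().mean()

loss.backward()

Again, this covers saved activations only; gradients and optimizer state stay on the GPU, and the transfers add latency unless overlapped.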

Additional context

Code that demonstrates the high VRAM usage

With backprop (~6 GB VRAM)

import torch
import torch.nn as nn
import torch.optim as optim

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        # fc1 and fc2 together hold ~302M parameters (~1.2 GB in fp32)
        self.fc1 = nn.Linear(128, 1024*256)
        self.fc2 = nn.Linear(1024*256, 1024)
        self.fc3 = nn.Linear(1024, 16)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

torch.cuda.set_device(0)
torch.cuda.empty_cache()

model = SimpleModel().cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

input_data = torch.randn(128, 128).pin_memory().cuda(non_blocking=True)
target_data = torch.randn(128, 16).pin_memory().cuda(non_blocking=True)

output = model(input_data)
loss = loss_fn(output, target_data)

loss.backward(retain_graph=False)  # allocates a gradient tensor for every parameter
torch.cuda.synchronize()

optimizer.step()  # first step lazily allocates two Adam state tensors per parameter
optimizer.zero_grad(set_to_none=True)

Without backprop (~1.6 GB VRAM)

Same script as above, but with the backward and optimizer steps commented out:

#loss.backward(retain_graph=False)
#torch.cuda.synchronize()

#optimizer.step()
#optimizer.zero_grad(set_to_none=True)
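For context, a rough fp32 accounting (assuming default Adam) suggests that most of the gap between ~1.6 GB and ~6 GB in this repro is gradients plus optimizer state rather than saved activations: backward() allocates one gradient tensor per parameter, and the first optimizer.step() allocates two Adam state tensors per parameter.

# Rough fp32 accounting for SimpleModel above (assumes default Adam).
params = (128 * 262144 + 262144) + (262144 * 1024 + 1024) + (1024 * 16 + 16)

def gib(n):  # fp32 tensors: 4 bytes per element
    return n * 4 / 1024**3

print(f"parameters:        {gib(params):.2f} GiB")      # ~1.13 GiB
print(f"gradients:         {gib(params):.2f} GiB")      # allocated by backward()
print(f"Adam exp_avg(+sq): {gib(2 * params):.2f} GiB")  # allocated by step()
# Total ~4.5 GiB before activations, close to the observed ~4.4 GB gap.

If that accounting is right, chunking or offloading in this repro would need to target gradients and optimizer state as much as activations.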

cc @ezyang @albanD @gqchen @pearu @nikitaved @soulitzer @Varal7 @xmfan

Metadata

Assignees: No one assigned
Labels: module: autograd, module: memory usage, needs research, triaged
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
