🚀 The feature, motivation and pitch
Backprop uses a lot of VRAM; its peak usage can reach several times the size of the model parameters and input data, which lowers effective GPU utilization. Compute and power savings could be realized if backprop's VRAM usage were optimized instead.
Alternatives
Possible solutions:
- Backprop calculation by parts. Instead of computing the whole backward pass as one big chain of matrices, the chain could be split into parts (e.g. per batch slice or per layer segment), and the intermediates of parts that have already been processed could be freed (see the checkpointing sketch after this list).
- Utilize async/unified memory for intermediate data. Today, when the backward computation exceeds the available VRAM, the process is killed with an out-of-memory error; instead, the excess data could be spilled to CPU RAM via a unified-memory (or explicit offload) mechanism (see the CPU-offload sketch below).
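As a rough illustration of the first idea, PyTorch's existing `torch.utils.checkpoint` already trades recomputation for activation memory: intermediates inside the checkpointed segment are dropped after the forward pass and recomputed during backward. A minimal sketch (layer sizes borrowed from the repro below, everything else illustrative):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(128, 1024 * 256)
        self.fc2 = nn.Linear(1024 * 256, 1024)
        self.fc3 = nn.Linear(1024, 16)

    def forward(self, x):
        # Wrap the two large layers so their activations are not kept alive
        # for the whole backward pass; they are recomputed when needed.
        def big_block(inp):
            h = torch.relu(self.fc1(inp))
            return torch.relu(self.fc2(h))

        x = checkpoint(big_block, x, use_reentrant=False)
        return self.fc3(x)

model = CheckpointedModel().cuda()
out = model(torch.randn(128, 128, device="cuda"))
out.sum().backward()
```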
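For the second idea, a related mechanism already exists: `torch.autograd.graph.save_on_cpu` moves tensors saved for backward to (optionally pinned) host memory and copies them back to the GPU on demand. This is explicit offloading rather than true CUDA unified memory, but it sketches the intended behaviour:

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(128, 1024 * 256), nn.ReLU(),
    nn.Linear(1024 * 256, 1024), nn.ReLU(),
    nn.Linear(1024, 16),
).cuda()

x = torch.randn(128, 128, device="cuda")

# Activations saved for backward are stored in pinned CPU memory during the
# forward pass and transferred back to the GPU only when backward needs them.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    out = model(x)

out.sum().backward()
```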
Additional context
Code that demonstrates the high VRAM usage.
With backprop (6GB VRAM)
```python
import torch
import torch.nn as nn
import torch.optim as optim

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        # The wide hidden layer makes the weight matrices (and later their
        # gradients and Adam states) large.
        self.fc1 = nn.Linear(128, 1024*256)
        self.fc2 = nn.Linear(1024*256, 1024)
        self.fc3 = nn.Linear(1024, 16)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

torch.cuda.set_device(0)
torch.cuda.empty_cache()

model = SimpleModel().cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

input_data = torch.randn(128, 128).pin_memory().cuda(non_blocking=True)
target_data = torch.randn(128, 16).pin_memory().cuda(non_blocking=True)

output = model(input_data)
loss = loss_fn(output, target_data)

# backward() allocates a gradient tensor per parameter, and step() adds the
# Adam optimizer state, which is where most of the extra VRAM goes.
loss.backward(retain_graph=False)
torch.cuda.synchronize()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
```
Without backprop (1.6GB VRAM)
```python
import torch
import torch.nn as nn
import torch.optim as optim

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(128, 1024*256)
        self.fc2 = nn.Linear(1024*256, 1024)
        self.fc3 = nn.Linear(1024, 16)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

torch.cuda.set_device(0)
torch.cuda.empty_cache()

model = SimpleModel().cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

input_data = torch.randn(128, 128).pin_memory().cuda(non_blocking=True)
target_data = torch.randn(128, 16).pin_memory().cuda(non_blocking=True)

output = model(input_data)
loss = loss_fn(output, target_data)

# Same script with the backward/optimizer steps disabled: only the parameters
# and forward activations stay resident, so peak VRAM is much lower.
#loss.backward(retain_graph=False)
#torch.cuda.synchronize()
#optimizer.step()
#optimizer.zero_grad(set_to_none=True)
```
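For reference, peak figures like the 6 GB / 1.6 GB quoted above can be checked with the allocator statistics PyTorch already exposes; a minimal sketch (using a single large layer rather than the full model, purely illustrative):

```python
import torch
import torch.nn as nn

torch.cuda.reset_peak_memory_stats()

model = nn.Linear(128, 1024 * 256).cuda()
out = model(torch.randn(128, 128, device="cuda"))
out.sum().backward()

# max_memory_allocated() reports the peak VRAM handed out by the caching
# allocator since the last reset, which is how the figures can be verified.
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")
```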
cc @ezyang @albanD @gqchen @pearu @nikitaved @soulitzer @Varal7 @xmfan