Description
Activation checkpointing (AC) saves memory by not storing intermediate tensors during the forward pass; instead, it recomputes the forward on demand to rematerialize the values needed for gradient computation during backward.
For pipelining, we split the backward computation into stage_backward_input and stage_backward_weight (https://github.com/pytorch/pytorch/blob/main/torch/distributed/pipelining/_backward.py), which compute the derivative with respect to the inputs and the derivative with respect to the weights, respectively. This split is described in more detail in the ZeroBubble paper. We issue a separate backward call per parameter group, and under AC each of these calls triggers another recomputation of the forward, which hurts performance.
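For intuition, the input/weight split can be expressed with plain autograd. Below is a simplified sketch on a single linear layer; it skips the param-group bookkeeping that the real helpers do:

import torch
import torch.nn as nn

lin = nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
loss = lin(x).sum()

# "stage_backward_input": only d(loss)/d(input); keep the graph alive so the
# weight gradients can still be computed later in the schedule
dx, = torch.autograd.grad(loss, inputs=(x,), retain_graph=True)

# "stage_backward_weight": d(loss)/d(weights), run at a later point in the schedule
dw, dbias = torch.autograd.grad(loss, inputs=(lin.weight, lin.bias))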
Minimal example
import torch
from torch.utils.checkpoint import checkpoint

a = torch.tensor(1., requires_grad=True)
b = torch.tensor(1., requires_grad=True)

def fn(x, z):
    print("computing forward")
    y = x.sin().exp() * z
    return y, y.sin().exp()

# compute forward under AC
y, out = checkpoint(fn, a, b, use_reentrant=False)

# 1. compute grads wrt a and the intermediate y
da, dy = torch.autograd.grad(out, inputs=(a, y), retain_graph=True)

# 2. compute grad wrt b, feeding in y's grad
db = torch.autograd.grad((y,), inputs=(b,), grad_outputs=(dy,))
Output (one print for the original forward, then one more per autograd.grad call, since each backward re-runs the checkpointed fn):
computing forward
computing forward
computing forward
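For comparison, batching everything into a single autograd.grad call triggers only one recompute ("computing forward" prints twice in total), which is exactly the batching that the split backward above cannot do:

y, out = checkpoint(fn, a, b, use_reentrant=False)
# one backward call over all inputs -> the checkpointed forward is recomputed once
da, db = torch.autograd.grad(out, inputs=(a, b))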
torch.distributed.pipelining minimal example
import torch
import torch.nn as nn
from torch.distributed.pipelining._backward import stage_backward_input, stage_backward_weight
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

def example(checkpoint):
    torch.manual_seed(0)
    x = torch.tensor((1.,), requires_grad=True)

    class SimpleModel(nn.Module):
        def __init__(self, num_layers=5):
            super().__init__()
            self.layers = nn.Sequential(*[nn.Linear(1, 1) for _ in range(num_layers)])

        def forward(self, x):
            print("computing forward")
            return self.layers(x)

    model = SimpleModel()
    if checkpoint:
        model = checkpoint_wrapper(model)

    y = model(x)
    loss = y.mean()
    dx, param_groups = stage_backward_input([loss], output_grads=None, input_values=[x], weights=model.parameters())
    print(f"{len(param_groups)=}")
    stage_backward_weight(model.parameters(), param_groups=param_groups)

print("without AC:", "=" * 25)
example(checkpoint=False)
print("with AC:", "=" * 25)
example(checkpoint=True)
Output (with AC, stage_backward_input triggers one recompute, and stage_backward_weight triggers one recompute per parameter group):
without AC: =========================
computing forward
len(param_groups)=5
with AC: =========================
computing forward
computing forward
len(param_groups)=5
computing forward
computing forward
computing forward
computing forward
computing forward
Potential solution
@soulitzer mentioned potentially adding a flag to AC that caches the recomputed intermediates. We still need to discuss what the API for that would look like and how we could explicitly free the intermediates once we decide we are done with them; a rough sketch of the idea follows.
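As a strawman for what such caching could look like, here is a sketch built on torch.autograd.graph.saved_tensors_hooks. This is not the proposed API: cached_checkpoint is a hypothetical helper, it ignores RNG state, correctness checks, and nondeterministic ops, and the explicit cache.clear() stands in for the "free the intermediates" step:

import torch
from torch.autograd.graph import saved_tensors_hooks

def cached_checkpoint(fn, *args):
    # Hypothetical sketch: like non-reentrant AC, but the recomputed
    # intermediates are kept in `cache`, so every backward call after the
    # first one reuses them instead of re-running forward.
    cache = {}          # slot index -> recomputed tensor
    counter = {"n": 0}  # order in which autograd saves tensors during forward

    def pack(t):
        idx = counter["n"]
        counter["n"] += 1
        return idx      # drop the tensor, remember only its slot

    def unpack(idx):
        if not cache:
            # First unpack: recompute the forward once and capture everything
            # it saves, in the same order as the original forward.
            recomputed = []
            with torch.enable_grad(), saved_tensors_hooks(
                lambda t: recomputed.append(t) or (len(recomputed) - 1),
                lambda i: recomputed[i],
            ):
                fn(*args)
            cache.update(enumerate(recomputed))
        return cache[idx]

    with saved_tensors_hooks(pack, unpack):
        out = fn(*args)
    return out, cache   # caller clears `cache` when done to free the memory

With the earlier minimal example, this recomputes the forward once instead of twice:

(y, out), cache = cached_checkpoint(fn, a, b)
da, dy = torch.autograd.grad(out, inputs=(a, y), retain_graph=True)  # recomputes once
db = torch.autograd.grad((y,), inputs=(b,), grad_outputs=(dy,))      # served from cache
cache.clear()  # explicitly free the cached intermediates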
For pipelining, caching the intermediates would extend their lifetime in memory to cover the window between backward_input and backward_weight, but it should still help peak memory. Ideally, the forward would only need to be recomputed once for backward_input and once for backward_weight.
cc @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @soulitzer