Activation Checkpointing composability with split backward computation #145511
Open
@H-Huang

Description

Activation checkpointing reduces memory by not saving intermediate tensors during the forward pass; instead, it recomputes the forward on demand during backward to obtain the intermediate values required for gradient computation.

For pipelining, we split the backward computation into stage_backward_input and stage_backward_weight (https://github.com/pytorch/pytorch/blob/main/torch/distributed/pipelining/_backward.py), which compute the derivatives with respect to the inputs and the weights, respectively. This is described in more detail in the ZeroBubble paper. We run a backward pass for each parameter group, and under activation checkpointing each of these backward calls triggers another forward recomputation, which hurts performance.

Minimal example

import torch
from torch.utils.checkpoint import checkpoint

a = torch.tensor(1., requires_grad=True)
b = torch.tensor(1., requires_grad=True)

def fn(x, z):
    print("computing forward")
    y = x.sin().exp() * z
    return y, y.sin().exp()

# compute forward under AC (fn runs once here)
y, out = checkpoint(fn, a, b, use_reentrant=False)

# 1. compute grads wrt a and y; AC recomputes fn to unpack the saved tensors
da, dy = torch.autograd.grad(out, inputs=(a, y), retain_graph=True)

# 2. compute grad wrt b using y's grad; AC recomputes fn again
db = torch.autograd.grad((y,), inputs=(b,), grad_outputs=(dy,))

Output:

computing forward
computing forward
computing forward
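
For comparison, here is a minimal sketch of the same two-call pattern without activation checkpointing, reusing the fn, a, and b from the snippet above: the graph built by the original forward is retained across both grad calls, so the forward runs only once, at the cost of keeping fn's intermediates alive until the second call finishes.

import torch

a = torch.tensor(1., requires_grad=True)
b = torch.tensor(1., requires_grad=True)

def fn(x, z):
    print("computing forward")
    y = x.sin().exp() * z
    return y, y.sin().exp()

# no AC: build the graph once and keep it alive across both grad calls
y, out = fn(a, b)
da, dy = torch.autograd.grad(out, inputs=(a, y), retain_graph=True)
db, = torch.autograd.grad((y,), inputs=(b,), grad_outputs=(dy,))
# prints "computing forward" exactly once, but all of fn's intermediates
# stay in memory until the second grad call completes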

torch.distributed.pipelining minimal example

import torch
from torch.distributed.pipelining._backward import stage_backward_input, stage_backward_weight
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper


def example(checkpoint):
    torch.manual_seed(0)
    x = torch.tensor((1.,), requires_grad=True)
    class SimpleModel(nn.Module):
        def __init__(self, num_layers=5):
            super().__init__()
            self.layers = nn.Sequential(*[nn.Linear(1, 1) for _ in range(num_layers)])

        def forward(self, x):
            print("computing forward")
            return self.layers(x)

    model = SimpleModel()
    if checkpoint:
        model = checkpoint_wrapper(model)

    y = model(x)
    loss = y.mean()

    # backward wrt the stage input; under AC this triggers a forward recompute
    dx, param_groups = stage_backward_input([loss], output_grads=None, input_values=[x], weights=model.parameters())
    print(f"{len(param_groups)=}")
    # backward wrt the weights; under AC each parameter group triggers another recompute
    stage_backward_weight(model.parameters(), param_groups=param_groups)

print("without AC:", "=" * 25)
example(checkpoint=False)
print("with AC:", "=" * 25)
example(checkpoint=True)

Output:

without AC: =========================
computing forward
len(param_groups)=5
with AC: =========================
computing forward
computing forward
len(param_groups)=5
computing forward
computing forward
computing forward
computing forward
computing forward

Potential solution

@soulitzer mentioned potentially adding a flag to AC that caches the recomputed intermediates. We still need to discuss what the API for that would look like and how we could explicitly free the intermediates once we decide we are done with them.

For pipelining, caching the intermediates would keep them alive in memory for the time between backward_input and backward_weight, but this should still help peak memory. Ideally, the forward would only need to be run once for backward_input and once for backward_weight.
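
To make the intended behavior concrete, here is a minimal sketch of the cache-and-explicitly-free idea. RecomputeCache is a hypothetical helper, not an existing PyTorch API and not the proposed AC flag itself; it just memoizes one forward recomputation so that both backward phases reuse the same graph, and exposes an explicit free() for when we are done.

import torch

class RecomputeCache:
    # Hypothetical sketch: recompute the forward at most once, hand the cached
    # graph to both backward phases, and free it explicitly afterwards.
    def __init__(self, fn, *args):
        self.fn = fn
        self.args = args
        self._outputs = None

    def outputs(self):
        # Lazily recompute exactly once; later calls reuse the cached graph.
        if self._outputs is None:
            self._outputs = self.fn(*self.args)
        return self._outputs

    def free(self):
        # Drop the cached outputs; once no other references remain, the graph
        # and its saved intermediates can be freed.
        self._outputs = None

a = torch.tensor(1., requires_grad=True)
b = torch.tensor(1., requires_grad=True)

def fn(x, z):
    print("computing forward")
    y = x.sin().exp() * z
    return y, y.sin().exp()

cache = RecomputeCache(fn, a, b)

# "backward_input": grads wrt a (and y, which the weight pass needs)
y, out = cache.outputs()
da, dy = torch.autograd.grad(out, inputs=(a, y), retain_graph=True)

# "backward_weight": reuses the cached graph, no extra forward
y, _ = cache.outputs()
db, = torch.autograd.grad((y,), inputs=(b,), grad_outputs=(dy,))

cache.free()  # explicitly release the cache once both backward calls are done

This prints "computing forward" once; the tradeoff is the one described above: the recomputed intermediates stay alive between the input and weight backward calls.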

cc @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @soulitzer

Labels

module: activation checkpointing, module: pipelining, oncall: distributed, triaged
