DDP doesn't work with retain_graph = True #47260
Open
@pritamdamania87

Description

🐛 Bug

DDP doesn't work with retain_graph = True when running backward more than once through the same model.

To Reproduce

To replicate, change only def demo_basic(rank, world_size) in https://pytorch.org/tutorials/intermediate/ddp_tutorial.html to the following:

def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=1)

    optimizer.zero_grad()

    # run three forward passes through the same DDP model and keep their outputs (and graphs) alive
    outputs = {}
    outputs['0'] = ddp_model(torch.rand(20, 10))
    outputs['1'] = ddp_model(torch.rand(20, 10))
    outputs['2'] = ddp_model(torch.rand(20, 10))

    labels = torch.rand(20, 5).to(rank)

    for i in range(3):
        print(f"before {i}, rank: {rank}, weight: {ddp_model.module.net1.weight[0][0]}")

        # the first two backward calls retain the graph; only the last one frees it
        if i < 2:
            loss_fn(outputs[str(i)], labels).backward(retain_graph=True)
        else:
            loss_fn(outputs[str(i)], labels).backward()

        print(f"after {i}, rank: {rank}, weight: {ddp_model.module.net1.weight[0][0]}, grad: {ddp_model.module.net1.weight.grad[0][0]}")

    optimizer.step()
    print(f"last, rank: {rank}, weight: {ddp_model.module.net1.weight[0][0]}, grad: {ddp_model.module.net1.weight.grad[0][0]}")

    cleanup()

and the output is:

before 0, rank: 0, weight: 0.1450435221195221
before 0, rank: 3, weight: 0.1450435221195221
before 0, rank: 1, weight: 0.1450435221195221
before 0, rank: 2, weight: 0.1450435221195221
after 0, rank: 0, weight: 0.1450435221195221, grad: -0.018003715202212334
before 1, rank: 0, weight: 0.1450435221195221
after 0, rank: 3, weight: 0.1450435221195221, grad: -0.018003715202212334
after 0, rank: 1, weight: 0.1450435221195221, grad: -0.018003715202212334
after 0, rank: 2, weight: 0.1450435221195221, grad: -0.018003715202212334
before 1, rank: 1, weight: 0.1450435221195221
before 1, rank: 3, weight: 0.1450435221195221
before 1, rank: 2, weight: 0.1450435221195221
after 1, rank: 0, weight: 0.1450435221195221, grad: -0.03955963999032974
after 1, rank: 3, weight: 0.1450435221195221, grad: -0.03072114661335945
before 2, rank: 0, weight: 0.1450435221195221
before 2, rank: 3, weight: 0.1450435221195221
after 1, rank: 1, weight: 0.1450435221195221, grad: -0.03775426745414734
before 2, rank: 1, weight: 0.1450435221195221
after 1, rank: 2, weight: 0.1450435221195221, grad: -0.03235533833503723
before 2, rank: 2, weight: 0.1450435221195221
after 2, rank: 0, weight: 0.1450435221195221, grad: -0.06408560276031494
after 2, rank: 3, weight: 0.1450435221195221, grad: -0.04222358390688896
after 2, rank: 1, weight: 0.1450435221195221, grad: -0.056242190301418304
last, rank: 0, weight: 0.20912912487983704, grad: -0.06408560276031494
last, rank: 3, weight: 0.18726710975170135, grad: -0.04222358390688896
last, rank: 1, weight: 0.201285719871521, grad: -0.056242190301418304
after 2, rank: 2, weight: 0.1450435221195221, grad: -0.04413666948676109
last, rank: 2, weight: 0.1891801953315735, grad: -0.04413666948676109

As the output shows, the gradients diverge across ranks starting from the second backward pass, and the weights diverge after optimizer.step(); they are not being synchronized.

Expected behavior

Gradients should be averaged across ranks on every backward pass, including those run with retain_graph=True.
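
For reference, one way to check what "synchronized" means here is to gather the same gradient entry from every rank and compare them. This is a sketch, not part of the original report: grads_are_synced is a hypothetical helper and assumes the process group from the tutorial is already initialized.

import torch
import torch.distributed as dist

def grads_are_synced(ddp_model):
    # Gather a copy of one parameter's gradient from every rank and
    # check that all copies are numerically identical.
    grad = ddp_model.module.net1.weight.grad.detach().clone()
    gathered = [torch.zeros_like(grad) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, grad)
    return all(torch.allclose(gathered[0], g) for g in gathered)

With a working reducer this should return True on every rank after each backward pass; with the repro above it would return False from the second backward onward, matching the output shown.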

Additional context

This seems to be a gap in DDP: it doesn't appear to support running backward more than once. I couldn't find any tests that use retain_graph=True with DDP.

The problem seems to be this line: https://github.com/pytorch/pytorch/blob/master/torch/csrc/distributed/c10d/reducer.cpp#L530, which skips any gradient reduction. The flag checked there is set to false once the first backward finishes (https://github.com/pytorch/pytorch/blob/master/torch/csrc/distributed/c10d/reducer.cpp#L1152) and is never set back to true, since prepare_for_backward is not called again.
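
A possible stopgap until this is fixed in the reducer is to average the gradients by hand after the extra backward passes. This is only a workaround sketch using the standard torch.distributed collectives, not an official DDP fix; manually_average_grads is a hypothetical helper.

import torch.distributed as dist

def manually_average_grads(ddp_model):
    # Average every parameter's gradient across all ranks, mimicking what
    # the DDP reducer would normally do during backward.
    world_size = dist.get_world_size()
    for param in ddp_model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size

Calling this on every rank before optimizer.step() should keep the replicas in lockstep, at the cost of an extra round of collectives outside the reducer's bucketing.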

cc @ezyang @gchanan @zou3519 @bdhirsh @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @agolynski @SciPioneer @H-Huang @mrzzd @xush6528

Metadata

Assignees

Labels

high priority
oncall: distributed (Add this issue/PR to distributed oncall triage queue)
triage review
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
