DDP doesn't work with retain_graph=True when trying to run backward twice through the same model.
To replicate, change only def demo_basic(rank, world_size) in https://pytorch.org/tutorials/intermediate/ddp_tutorial.html to the following:
def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=1)

    optimizer.zero_grad()
    outputs = {}
    outputs['0'] = ddp_model(torch.rand(20, 10))
    outputs['1'] = ddp_model(torch.rand(20, 10))
    outputs['2'] = ddp_model(torch.rand(20, 10))
    labels = torch.rand(20, 5).to(rank)
    for i in range(3):
        print(f"before {i}, rank: {rank}, weight: {ddp_model.module.net1.weight[0][0]}")
        if i < 2:
            loss_fn(outputs[str(i)], labels).backward(retain_graph=True)
        else:
            loss_fn(outputs[str(i)], labels).backward()
        print(f"after {i}, rank: {rank}, weight: {ddp_model.module.net1.weight[0][0]}, grad: {ddp_model.module.net1.weight.grad[0][0]}")
    optimizer.step()
    print(f"last, rank: {rank}, weight: {ddp_model.module.net1.weight[0][0]}, grad: {ddp_model.module.net1.weight.grad[0][0]}")
    cleanup()
and the output is:
before 0, rank: 0, weight: 0.1450435221195221
before 0, rank: 3, weight: 0.1450435221195221
before 0, rank: 1, weight: 0.1450435221195221
before 0, rank: 2, weight: 0.1450435221195221
after 0, rank: 0, weight: 0.1450435221195221, grad: -0.018003715202212334
before 1, rank: 0, weight: 0.1450435221195221
after 0, rank: 3, weight: 0.1450435221195221, grad: -0.018003715202212334
after 0, rank: 1, weight: 0.1450435221195221, grad: -0.018003715202212334
after 0, rank: 2, weight: 0.1450435221195221, grad: -0.018003715202212334
before 1, rank: 1, weight: 0.1450435221195221
before 1, rank: 3, weight: 0.1450435221195221
before 1, rank: 2, weight: 0.1450435221195221
after 1, rank: 0, weight: 0.1450435221195221, grad: -0.03955963999032974
after 1, rank: 3, weight: 0.1450435221195221, grad: -0.03072114661335945
before 2, rank: 0, weight: 0.1450435221195221
before 2, rank: 3, weight: 0.1450435221195221
after 1, rank: 1, weight: 0.1450435221195221, grad: -0.03775426745414734
before 2, rank: 1, weight: 0.1450435221195221
after 1, rank: 2, weight: 0.1450435221195221, grad: -0.03235533833503723
before 2, rank: 2, weight: 0.1450435221195221
after 2, rank: 0, weight: 0.1450435221195221, grad: -0.06408560276031494
after 2, rank: 3, weight: 0.1450435221195221, grad: -0.04222358390688896
after 2, rank: 1, weight: 0.1450435221195221, grad: -0.056242190301418304
last, rank: 0, weight: 0.20912912487983704, grad: -0.06408560276031494
last, rank: 3, weight: 0.18726710975170135, grad: -0.04222358390688896
last, rank: 1, weight: 0.201285719871521, grad: -0.056242190301418304
after 2, rank: 2, weight: 0.1450435221195221, grad: -0.04413666948676109
last, rank: 2, weight: 0.1891801953315735, grad: -0.04413666948676109
Weights and grads do not seem to be synchronized: the grads match across all ranks after the first backward, but after the second and third backward each rank ends up with a different grad, and the weights then diverge after optimizer.step(). Grads should be synchronized (all-reduced) on every backward.
This seems to be a gap in DDP: it doesn't appear to support running backward more than once over the same graph. I couldn't find any tests that use retain_graph=True with DDP.
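For comparison, grads should stay synchronized if the three losses are summed and backward is called only once, since DDP then performs its single gradient reduction over the combined graph. This is only an illustrative sketch (not a fix for the retain_graph case), reusing the same setup/cleanup/ToyModel helpers assumed from the tutorial:

def demo_single_backward(rank, world_size):
    # Same setup as the repro above.
    setup(rank, world_size)
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=1)

    optimizer.zero_grad()
    labels = torch.rand(20, 5).to(rank)
    # Accumulate the three losses and call backward once: DDP reduces grads
    # once per backward, so all ranks should end up with the same grad.
    total_loss = sum(loss_fn(ddp_model(torch.rand(20, 10)), labels) for _ in range(3))
    total_loss.backward()
    print(f"rank: {rank}, grad: {ddp_model.module.net1.weight.grad[0][0]}")
    optimizer.step()
    cleanup()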
The problem seems to be this line: https://github.com/pytorch/pytorch/blob/master/torch/csrc/distributed/c10d/reducer.cpp#L530, which causes the reducer to skip any gradient reduction. The flag tested there is set to false once the first backward finishes: https://github.com/pytorch/pytorch/blob/master/torch/csrc/distributed/c10d/reducer.cpp#L1152, and it is never set back to true because prepare_for_backward is not called again before the subsequent backward calls.
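As a crude way to confirm that the reduction is what is being skipped (not a fix), the grads can be averaged by hand with an explicit all-reduce after the later backward calls. A minimal sketch, assuming dist is torch.distributed as imported in the tutorial and the hypothetical helper name is mine:

import torch.distributed as dist

def manually_average_grads(ddp_model, world_size):
    # Hypothetical manual sync: sum each grad across ranks and divide by
    # world_size, which is what DDP's reduction would normally produce.
    for param in ddp_model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size

Calling manually_average_grads(ddp_model, world_size) right after each backward in the repro would leave every rank holding the same averaged grad, i.e. the graph itself is intact and only the reducer's synchronization step is missing on the repeated backwards.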
cc @ezyang @gchanan @zou3519 @bdhirsh @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @agolynski @SciPioneer @H-Huang @mrzzd @xush6528