torch.distributed.nn.all_reduce incorrectly scales the gradient · Issue #58005 · pytorch/pytorch
Open
@DrJimFan

Description


🐛 Bug

torch.distributed.nn.all_reduce computes different gradient values from torch.distributed.all_reduce. In particular, it seems to scale the gradients by an extra factor of world_size.

To Reproduce

Script to reproduce the behavior:

import torch
import torch.distributed
import torch.distributed.nn
import torch.multiprocessing

# 1: autograd-aware torch.distributed.nn.all_reduce; 0: plain torch.distributed.all_reduce
USE_NN_REDUCE = 1

def main_worker(gpu):
    # One process per GPU; the rank doubles as the GPU index.
    torch.distributed.init_process_group(
        backend="nccl", init_method="tcp://localhost:12345", world_size=2, rank=gpu
    )
    torch.cuda.set_device(gpu)
    x = torch.ones(1).cuda(device=gpu).requires_grad_()
    # Rank 0 computes 2 * x, rank 1 computes 3 * x, so the reduced value is 5.
    xx = (gpu + 2) * x
    if USE_NN_REDUCE:
        print("nn.all_reduce")
        xx = torch.distributed.nn.all_reduce(xx)
    else:
        print("regular all_reduce")
        torch.distributed.all_reduce(xx)  # in place, not autograd-aware
    print("Value after all_reduce:", xx)
    xx.backward()
    print(f"gpu={gpu}, grad={x.grad}")

if __name__ == "__main__":
    torch.multiprocessing.spawn(main_worker, nprocs=2)
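
The script spawns one process per GPU and assumes two visible CUDA GPUs with a NCCL-enabled build; toggle USE_NN_REDUCE to switch between the two collectives.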

For torch.distributed.nn.all_reduce (i.e. USE_NN_REDUCE=1), the gradient comes out scaled by an extra factor of world_size:

Value after all_reduce: tensor([5.], device='cuda:0', grad_fn=<_AllReduceBackward>)
Value after all_reduce: tensor([5.], device='cuda:1', grad_fn=<_AllReduceBackward>)
gpu=0, grad=tensor([4.], device='cuda:0')
gpu=1, grad=tensor([6.], device='cuda:1')
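
For what it's worth, these numbers are consistent with _AllReduceBackward running a second all_reduce on the incoming gradient before it reaches the local MulBackward0 (my reading of the output above, not of the implementation):

grad_xx = 1 on every rank                    # from xx.backward()
grad_xx -> 1 + 1 = 2                         # backward all_reduce over 2 ranks
grad_x  = (gpu + 2) * grad_xx                # local MulBackward0
        = 4 on gpu 0, 6 on gpu 1             # world_size times the expected 2 and 3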

For torch.distributed.all_reduce (i.e. USE_NN_REDUCE=0), it is correct; the in-place reduction is invisible to autograd, so only the local MulBackward0 contributes:

Value after all_reduce: tensor([5.], device='cuda:0', grad_fn=<MulBackward0>)
Value after all_reduce: tensor([5.], device='cuda:1', grad_fn=<MulBackward0>)
gpu=0, grad=tensor([2.], device='cuda:0')
gpu=1, grad=tensor([3.], device='cuda:1')

Expected behavior

torch.distributed.nn.all_reduce should produce the correct gradient without scaling by world_size.
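
In the meantime, a possible workaround is to wrap the in-place collective in a custom autograd.Function whose backward returns the gradient unchanged. This is only a sketch, under the assumption that the pass-through gradient shown above is the desired semantics; _AllReducePassthrough and all_reduce_passthrough are made-up names, not PyTorch APIs:

import torch
import torch.distributed as dist

class _AllReducePassthrough(torch.autograd.Function):
    # Hypothetical workaround: all-reduce in forward, leave the gradient
    # untouched in backward (matching the gradient of the in-place
    # torch.distributed.all_reduce above).

    @staticmethod
    def forward(ctx, tensor):
        tensor = tensor.clone()  # all_reduce mutates its argument in place
        dist.all_reduce(tensor)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # No second all_reduce here, so the extra world_size factor never appears.
        return grad_output

def all_reduce_passthrough(tensor):
    return _AllReducePassthrough.apply(tensor)

Replacing the torch.distributed.nn.all_reduce call in the repro with all_reduce_passthrough(xx) should give grad=2. and grad=3. again.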

Environment

PyTorch version: 1.8.1+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.10.2

Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: TITAN RTX
GPU 1: TITAN RTX
GPU 2: TITAN RTX
GPU 3: TITAN RTX
GPU 4: TITAN RTX
GPU 5: TITAN RTX
GPU 6: TITAN RTX
GPU 7: TITAN RTX

Nvidia driver version: 450.66
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] pytorch-lightning==1.2.7
[pip3] torch==1.8.1
[pip3] torchmetrics==0.3.1
[pip3] torchvision==0.9.1
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               10.2.89              hfd86e86_1
[conda] mkl                       2020.2                      256
[conda] mkl-include               2020.2                      256
[conda] mkl-service               2.3.0            py37he904b0f_0
[conda] mkl_fft                   1.2.0            py37h23d657b_0
[conda] mkl_random                1.1.1            py37h0573a6f_0
[conda] numpy                     1.18.5                   pypi_0    pypi
[conda] pytorch                   1.7.0           py3.7_cuda10.2.89_cudnn7.6.5_0    pytorch
[conda] pytorch-lightning         1.1.8                    pypi_0    pypi
[conda] torchvision               0.8.1                py37_cu102    pytorch

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @agolynski @SciPioneer @H-Huang @mrzzd @cbalioglu @gcramer23

Labels: high priority, oncall: distributed, triaged
