🐛 Bug
torch.distributed.nn.all_reduce computes different gradient values from torch.distributed.all_reduce. In particular, it seems to scale the gradients by world_size incorrectly.
To Reproduce
Script to reproduce the behavior:
import torch
import torch.distributed
import torch.distributed.nn

USE_NN_REDUCE = 1


def main_worker(gpu):
    torch.distributed.init_process_group(
        backend="nccl", init_method="tcp://localhost:12345", world_size=2, rank=gpu
    )
    torch.cuda.set_device(gpu)

    # Each rank contributes (gpu + 2) * x, so the all-reduced value is 2*x + 3*x = 5*x.
    x = torch.ones(1).cuda(device=gpu).requires_grad_()
    xx = (gpu + 2) * x

    if USE_NN_REDUCE:
        print("nn.all_reduce")
        xx = torch.distributed.nn.all_reduce(xx)
    else:
        print("regular all_reduce")
        torch.distributed.all_reduce(xx)

    print("Value after all_reduce:", xx)

    xx.backward()
    print(f"gpu={gpu}, grad={x.grad}")


if __name__ == "__main__":
    torch.multiprocessing.spawn(main_worker, nprocs=2)
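The script can be run directly on a machine with at least two GPUs; torch.multiprocessing.spawn starts one process per rank (the filename repro.py below is just a placeholder for wherever the script is saved):

python repro.py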
For torch.distributed.nn.all_reduce (i.e. USE_NN_REDUCE=1), the gradient seems to be scaled by world_size incorrectly:
Value after all_reduce: tensor([5.], device='cuda:0', grad_fn=<_AllReduceBackward>)
Value after all_reduce: tensor([5.], device='cuda:1', grad_fn=<_AllReduceBackward>)
gpu=0, grad=tensor([4.], device='cuda:0')
gpu=1, grad=tensor([6.], device='cuda:1')
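For reference, these numbers are consistent with the backward pass also summing the incoming gradient across ranks (an assumption about what _AllReduceBackward does, inferred from the output rather than from the source): each rank's grad_output of 1 becomes 1 + 1 = 2 after the reduction, and is then multiplied by the local factor gpu + 2. A minimal single-process sketch of that arithmetic:

# Sketch only: simulates the suspected gradient flow on one process.
world_size = 2
for gpu in range(world_size):
    local_factor = gpu + 2                   # xx = (gpu + 2) * x
    grad_output = 1.0                        # incoming gradient from xx.backward()
    reduced_grad = grad_output * world_size  # grad_output summed over both ranks
    print(f"gpu={gpu}, grad={reduced_grad * local_factor}")  # 4.0 and 6.0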
For torch.distributed.all_reduce, it is correct:
Value after all_reduce: tensor([5.], device='cuda:0', grad_fn=<MulBackward0>)
Value after all_reduce: tensor([5.], device='cuda:1', grad_fn=<MulBackward0>)
gpu=0, grad=tensor([2.], device='cuda:0')
gpu=1, grad=tensor([3.], device='cuda:1')
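These values match the purely local derivative: torch.distributed.all_reduce is not differentiable, so autograd only records the local multiplication (hence the MulBackward0 grad_fn) and d(xx)/dx = gpu + 2:

# Sketch only: with the non-differentiable all_reduce, only the local graph contributes.
for gpu in range(2):
    print(f"gpu={gpu}, grad={float(gpu + 2)}")  # 2.0 and 3.0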
Expected behavior
torch.distributed.nn.all_reduce should produce the correct gradient without scaling by world_size.
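To make the expectation concrete, here is a minimal sketch (not part of the report or of the PyTorch API; the class name is hypothetical) of an all_reduce wrapper whose backward leaves the incoming gradient untouched. Used in the script above via AllReduceNoGradScaling.apply(xx) instead of torch.distributed.nn.all_reduce, it would yield grads of 2 and 3:

import torch
import torch.distributed


class AllReduceNoGradScaling(torch.autograd.Function):
    # Hypothetical illustration of the expected semantics:
    # sum across ranks in forward, pass the gradient through unchanged in backward.
    @staticmethod
    def forward(ctx, tensor):
        out = tensor.clone()
        torch.distributed.all_reduce(out)  # in-place sum across ranks
        return out

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output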
Environment
PyTorch version: 1.8.1+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.10.2
Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: TITAN RTX
GPU 1: TITAN RTX
GPU 2: TITAN RTX
GPU 3: TITAN RTX
GPU 4: TITAN RTX
GPU 5: TITAN RTX
GPU 6: TITAN RTX
GPU 7: TITAN RTX
Nvidia driver version: 450.66
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] pytorch-lightning==1.2.7
[pip3] torch==1.8.1
[pip3] torchmetrics==0.3.1
[pip3] torchvision==0.9.1
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.2.89 hfd86e86_1
[conda] mkl 2020.2 256
[conda] mkl-include 2020.2 256
[conda] mkl-service 2.3.0 py37he904b0f_0
[conda] mkl_fft 1.2.0 py37h23d657b_0
[conda] mkl_random 1.1.1 py37h0573a6f_0
[conda] numpy 1.18.5 pypi_0 pypi
[conda] pytorch 1.7.0 py3.7_cuda10.2.89_cudnn7.6.5_0 pytorch
[conda] pytorch-lightning 1.1.8 pypi_0 pypi
[conda] torchvision 0.8.1 py37_cu102 pytorch
cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @agolynski @SciPioneer @H-Huang @mrzzd @cbalioglu @gcramer23