AllReduce hangs due to no device_id into init_process_group · Issue #142356 · pytorch/pytorch · GitHub
AllReduce hangs due to no device_id into init_process_group #142356

Open
youth123 opened this issue Dec 9, 2024 · 6 comments
Labels
module: c10d Issues/PRs related to collective communications and process groups oncall: distributed Add this issue/PR to distributed oncall triage queue

Comments

@youth123
youth123 commented Dec 9, 2024

🐛 Describe the bug

In PyTorch 2.2 or later, when initializing a distributed environment, if you do not pass device_id to init_process_group but instead set the device via torch.cuda.set_device, the allreduce primitive hangs or NCCL reports an error directly. The issue only comes up with the allreduce primitive; does anyone know why?
You can run the example below with torchrun --nproc_per_node 4 test.py

import torch
import torch.distributed as dist
import os

init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
world_size = 4

default_master_port = '6000'
master_port = os.getenv('MASTER_PORT', default_master_port)
init_method += master_ip + ':' + master_port
rank = int(os.getenv('RANK', '0'))
# correct init
# torch.distributed.init_process_group(
#     backend="nccl",
#     world_size=world_size, rank=rank, device_id=torch.device(f"cuda:{rank}"), init_method=init_method)

# will cause allreduce hang
torch.cuda.set_device(torch.device(f"cuda:{rank}"))
torch.distributed.init_process_group(
    backend="nccl",
    world_size=world_size, rank=rank, init_method=init_method)

cur_rank = torch.distributed.get_rank()
if cur_rank == 0 or cur_rank == 1:
    gqa_group = torch.distributed.new_group([0, 1])
else:
    gqa_group = torch.distributed.new_group([2, 3])

a = torch.tensor(1, device=cur_rank)
torch.distributed.all_reduce(a, group=gqa_group)

Versions

PyTorch 2.5

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@malfet malfet added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Dec 9, 2024
@XilunWu
Contributor
XilunWu commented Dec 9, 2024

I tried to reproduce the hang but couldn't. Instead, it's an NCCL error:

misc/socket.cc:484 NCCL WARN socketStartConnect: Connect to 2803:6082:c0b4:8500::1<35393> failed : Software caused connection abort
[rank1]: Traceback (most recent call last):
[rank1]:   File "test_allreduce.py", line 49, in <module>
[rank1]:     torch.distributed.all_reduce(a, group=gqa_group)
[rank1]:   File "torch/distributed/c10d_logger.py", line 81, in wrapper
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "torch/distributed/distributed_c10d.py", line 2773, in all_reduce
[rank1]:     work = group.allreduce([tensor], opts)
[rank1]: torch.distributed.DistBackendError: NCCL error in: torch/csrc/distributed/c10d/NCCLUtils.hpp:269, unhandled system error (run with NCCL_DEBUG=INFO for details), NCCL version 2.21.5
[rank1]: ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. 
[rank1]: Last error:
[rank1]: socketStartConnect: Connect to <host>::1<35393> failed : Software caused connection abort
W1209 11:00:32.851000 1935266 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1935505 closing signal SIGTERM
W1209 11:00:32.852000 1935266 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1935507 closing signal SIGTERM
W1209 11:00:32.853000 1935266 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1935508 closing signal SIGTERM

The test code I'm using:

import os

import torch
import torch.distributed as dist

from torch.testing._internal.common_utils import find_free_port

init_method = "tcp://"
master_ip = os.getenv("MASTER_ADDR", "localhost")
world_size = 4

default_master_port = str(find_free_port())  # convert to str for the concatenation below
master_port = os.getenv("MASTER_PORT", default_master_port)
init_method += master_ip + ":" + master_port
rank = int(os.getenv("RANK", "0"))

# When `use_set_device == True`, the program has a good chance to fail with NCCL error.
use_set_device = True

if use_set_device:
    # will cause allreduce hang
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    dist.init_process_group(
        backend="nccl",
        world_size=world_size,
        rank=rank,
        init_method=init_method,
    )
else:
    # correct init
    dist.init_process_group(
        backend="nccl",
        world_size=world_size,
        rank=rank,
        device_id=torch.device(f"cuda:{rank}"),
        init_method=init_method,
    )


# test code
cur_rank = torch.distributed.get_rank()
if cur_rank == 0 or cur_rank == 1:
    gqa_group = torch.distributed.new_group([0, 1])
else:
    gqa_group = torch.distributed.new_group([2, 3])

a = torch.tensor(1, device=cur_rank)
torch.distributed.all_reduce(a, group=gqa_group)

torch.distributed.destroy_process_group()

cc @yifuwang @kwen2501

@wconstab
Contributor
wconstab commented Dec 9, 2024

IIUC this code is actually not correct: according to the docs for new_group, all ranks must call new_group (for each group creation).
[screenshot of the torch.distributed.new_group documentation]

specifically, this code:

if cur_rank == 0 or cur_rank == 1:
    gqa_group = torch.distributed.new_group([0, 1])
else:
    gqa_group = torch.distributed.new_group([2, 3])

should change to this:

if cur_rank == 0 or cur_rank == 1:
    gqa_group = torch.distributed.new_group([0, 1])
    _ = torch.distributed.new_group([2, 3])
else:
    _ = torch.distributed.new_group([0, 1])
    gqa_group = torch.distributed.new_group([2, 3])

cc @kwen2501 please confirm and keep me honest

@youth123
Author

I tried to reproduce the hang but couldn't. Instead, it's an NCCL error

The hang does not happen every time; sometimes NCCL reports a connection refused error instead.

@youth123
Author

if cur_rank == 0 or cur_rank == 1:
    gqa_group = torch.distributed.new_group([0, 1])
    _ = torch.distributed.new_group([2, 3])
else:
    _ = torch.distributed.new_group([0, 1])
    gqa_group = torch.distributed.new_group([2, 3])

After I modified the code as above, I still occasionally got an NCCL connection refused error when running it multiple times, so I don't think the error is caused by new_group.

@wconstab
Contributor

Well, adding device_id to init_process_group opts you into "eager init" for NCCL. Without it, you get lazy init, which means NCCL establishes its connections on the first collective. In general, I've only seen this type of hang when the collective or new_group calls are not well synchronized.

The original code with the incorrect new_group calls does cause the NCCL connect error for me, but the corrected code works fine; I can't repro an error after 5 back-to-back retries.

Can you confirm you're still seeing errors with this script?
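
For reference, here is a minimal sketch (not from this thread) that combines both suggestions: pass device_id for eager NCCL init, and have every rank create every subgroup in the same order. It assumes a single-node torchrun --nproc_per_node 4 launch, so RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT come from the environment and init_method can stay at its env:// default.

import os

import torch
import torch.distributed as dist

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# device_id opts into eager init: the NCCL communicator for the default
# group is created here instead of lazily on the first collective.
dist.init_process_group(
    backend="nccl",
    rank=rank,
    world_size=world_size,
    device_id=torch.device(f"cuda:{rank}"),
)

# Every rank calls new_group for every subgroup, in the same order;
# new_group returns a usable handle only to the ranks that belong to it.
group_01 = dist.new_group([0, 1])
group_23 = dist.new_group([2, 3])
gqa_group = group_01 if rank in (0, 1) else group_23

a = torch.tensor(1, device=f"cuda:{rank}")
dist.all_reduce(a, group=gqa_group)

dist.destroy_process_group()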

@youth123
Author

Can you confirm you're still seeing errors with this script?

I did see the error after trying it a dozen times. Here's my error log.

[rank3]:[E1211 09:42:50.285517203 ProcessGroupNCCL.cpp:542] [Rank 1] Collective WorkNCCL(SeqNum=1, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=600000) raised the following async exception: NCCL error: unhandled system error (run with NCCL_DEBUG=INFO for details), NCCL version 2.21.5
ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error.
Last error:
Exception raised from checkForNCCLErrorsInternal at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:2027 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7fd8e1b6c446 in /root/.local/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::checkForNCCLErrorsInternal(std::shared_ptr<c10d::NCCLComm>&) + 0x220 (0x7fd897a29f70 in /root/.local/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::WorkNCCL::checkAndSetException() + 0x7c (0x7fd897a2a1bc in /root/.local/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::isStarted() + 0x90 (0x7fd897a2a490 in /root/.local/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::watchdogHandler() + 0x9f8 (0x7fd897a32368 in /root/.local/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7fd897a3360d in /root/.local/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #6: + 0x145c0 (0x7fd8e20185c0 in /root/.local/lib/python3.10/site-packages/torch/lib/libtorch.so)
frame #7: + 0x94ac3 (0x7fd8e2869ac3 in /lib/x86_64-linux-gnu/libc.so.6)
frame #8: + 0x126850 (0x7fd8e28fb850 in /lib/x86_64-linux-gnu/libc.so.6)

@yf225 yf225 added the module: c10d Issues/PRs related to collective communications and process groups label Dec 26, 2024