AllReduce hangs due to no device_id into init_process_group #142356
Comments
I tried to reproduce the hang but couldn't. Instead, I get an NCCL error.
The test code I'm using:
IIUC this code is actually not correct. According to the docs for new_group, all ranks must call new_group (for each group creation). Specifically, this code:
should change to this:
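The before/after snippets were not captured in this page scrape, but the fix the comment describes can be sketched as follows. This is a hypothetical illustration, not the issue's actual script; the group layout and helper names (`partition_ranks`, `create_subgroups`) are assumptions:

```python
def partition_ranks(world_size, group_size):
    """Split ranks [0, world_size) into contiguous subgroups of group_size."""
    return [list(range(i, i + group_size))
            for i in range(0, world_size, group_size)]

def create_subgroups(rank, world_size, group_size):
    import torch.distributed as dist

    # Incorrect pattern (the kind of code the comment objects to): only the
    # member ranks call new_group, e.g.
    #   if rank in ranks:
    #       my_group = dist.new_group(ranks)
    #
    # Correct pattern: EVERY rank calls new_group for EVERY group, in the
    # same order, and keeps only the handle for the group it belongs to.
    my_group = None
    for ranks in partition_ranks(world_size, group_size):
        g = dist.new_group(ranks)
        if rank in ranks:
            my_group = g
    return my_group
```

The key point is that `dist.new_group` is itself a collective over the whole default group, so skipping the call on non-member ranks desynchronizes NCCL's setup.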
cc @kwen2501 please confirm and keep me honest
The hang is not deterministic; sometimes an NCCL "connection refused" error occurs instead.
After I modified the code as above, I still occasionally got an NCCL connection refused error when running it multiple times. I think the reported error is not caused by new_group.
Well, adding device_id to init_process_group opts you into "eager init" for NCCL. Without eager init, you get lazy init, which means NCCL establishes its connections on the first collective. In general, I've only seen this type of hang when the collective or new_group calls are not well synchronized. The original code with the incorrect new_group does cause the NCCL connect error for me, but the corrected code works fine; I can't repro an error after 5 back-to-back retries. Can you confirm you're still seeing errors with this script?
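The eager-vs-lazy distinction described here comes down to whether a device_id is supplied at init time. A minimal sketch (the `init_kwargs` helper is hypothetical, introduced only to illustrate the two modes):

```python
def init_kwargs(rank, eager):
    """Build kwargs for torch.distributed.init_process_group.

    Passing device_id opts into NCCL eager init: communicators are set up
    during init_process_group itself. Omitting it gives lazy init, where
    NCCL builds its connections on the first collective call instead.
    """
    kwargs = {"backend": "nccl"}
    if eager:
        import torch
        # Assumes one GPU per local rank, as in the torchrun repro.
        kwargs["device_id"] = torch.device(f"cuda:{rank}")
    return kwargs

# Usage inside each torchrun worker (sketch):
#   import torch.distributed as dist
#   dist.init_process_group(**init_kwargs(rank, eager=True))
```

With lazy init, a mis-synchronized new_group or collective can surface later, at the first collective, as a hang or connection error rather than at init time.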
I did see the error after trying it a dozen times. Here's my error log.
🐛 Describe the bug
In PyTorch 2.2 or later, when initializing a distributed environment, if you do not pass device_id to init_process_group but instead set the device via set_device, the allreduce primitive hangs or NCCL reports an error directly. This only comes up with the allreduce primitive; does anyone know why?
You can run the example below with
torchrun --nproc_per_node 4 test.py
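The attached test.py was not captured in this page. A minimal sketch of the pattern being described (device chosen only via set_device, no device_id passed to init_process_group) might look like the following; the variable names and tensor shape are assumptions, not the issue's actual script:

```python
import os

def local_rank_from_env(env):
    """Read the local rank that torchrun exports via LOCAL_RANK."""
    return int(env["LOCAL_RANK"])

def main():
    import torch
    import torch.distributed as dist

    rank = local_rank_from_env(os.environ)
    torch.cuda.set_device(rank)      # device chosen via set_device only...
    dist.init_process_group("nccl")  # ...no device_id here -> lazy NCCL init
    t = torch.ones(1, device="cuda")
    dist.all_reduce(t)               # reported to hang or raise an NCCL error
    dist.destroy_process_group()

# Only run under torchrun, which sets LOCAL_RANK for each worker.
if __name__ == "__main__" and "LOCAL_RANK" in os.environ:
    main()
```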
Versions
PyTorch 2.5
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o