[C10D] Blackwell B200: NotImplementedError: Could not run '_c10d_functional_autograd::all_to_all_single' with arguments from the 'CUDA' backend. · Issue #154370 · pytorch/pytorch · GitHub
Open
@lessw2020

Description


🐛 Describe the bug

I'm running DeepSeek-V2 inference with TorchTitan on B200. Normally we use the symmetric memory API for communication, but that fails on Blackwell due to the lack of Triton support for it (at least in PyTorch).
As a fallback I switched to dist.all2all, but I get:
"NotImplementedError: Could not run '_c10d_functional_autograd::all_to_all_single' with arguments from the 'CUDA' backend."

This is slightly odd, since I have put the model in eval mode, but potentially something in there still has grad=True.
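A possible workaround, sketched under assumptions that this report does not confirm: pick the collective based on whether autograd is active, so the autograd-wrapped op is only used when gradients are actually needed. Note that model.eval() leaves torch.is_grad_enabled() unchanged; only contexts such as torch.no_grad() or torch.inference_mode() turn it off. The helper pick_all_to_all is hypothetical (not part of TorchTitan or PyTorch). In real use, `collectives` is assumed to be torch.distributed._functional_collectives, which provides both all_to_all_single and all_to_all_single_autograd, and `grad_enabled` would come from torch.is_grad_enabled(). The stand-in functions only let the sketch run without GPUs or a process group.

```python
from types import SimpleNamespace


def pick_all_to_all(collectives, grad_enabled):
    """Choose an all-to-all variant for the MoE token exchange.

    Hypothetical helper. `collectives` is assumed to expose
    `all_to_all_single` and `all_to_all_single_autograd` (in real use,
    torch.distributed._functional_collectives). `grad_enabled` would
    typically be torch.is_grad_enabled(), which is False inside
    torch.no_grad() and torch.inference_mode().
    """
    if grad_enabled:
        # Training path: keep the autograd-aware op so gradients flow back.
        return collectives.all_to_all_single_autograd
    # Inference path: use the plain functional collective, which does not
    # go through the autograd-only op that fails in this report.
    return collectives.all_to_all_single


# Stand-ins so the sketch runs without a process group.
fake = SimpleNamespace(
    all_to_all_single=lambda *args, **kwargs: "plain",
    all_to_all_single_autograd=lambda *args, **kwargs: "autograd",
)
print(pick_all_to_all(fake, grad_enabled=False)())  # plain
print(pick_all_to_all(fake, grad_enabled=True)())   # autograd
```

In the model, the call at the all_to_all_single_autograd site would then go through the returned function with the same arguments; both variants are assumed here to take the same (tensor, output_split_sizes, input_split_sizes, group) arguments.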

Repro:
On Blackwell, disable symmetric memory in generate.py under torchtitan/experiments/deepseek_v3 by commenting out this line:

# model.setup_symm_mem(torch.bfloat16, dist_config.device)

Run inference via bash inference.sh "your prompt here" and you will see the NotImplementedError.

Here is the full trace:

Running inference with deepseek-ai/DeepSeek-V2-Lite-Chat on (1, 4) mesh
Creating model stage 0 of 1
Creating model stage 0 of 1
Creating model stage 0 of 1
Creating model stage 0 of 1
Generating: NCCL version 2.26.5+cuda12.9
[rank0]: Traceback (most recent call last):
[rank0]:   File "/data/users/less/torchtitan/torchtitan/experiments/deepseek_v3/generate.py", line 377, in <module>
[rank0]:     generate(model, pp_schedule, tokenizer, dist_config, messages)
[rank0]:   File "/data/users/less/torchtitan/torchtitan/experiments/deepseek_v3/generate.py", line 198, in wrapper
[rank0]:     result, tokens_generated = func(*args, **kwargs)
[rank0]:                                ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/less/torchtitan/torchtitan/experiments/deepseek_v3/generate.py", line 281, in generate
[rank0]:     preds = model(x)
[rank0]:             ^^^^^^^^
[rank0]:   File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/less/torchtitan/torchtitan/experiments/deepseek_v3/model.py", line 1322, in forward
[rank0]:     hidden_states = self.model(
[rank0]:                     ^^^^^^^^^^^
[rank0]:   File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/less/torchtitan/torchtitan/experiments/deepseek_v3/model.py", line 1274, in forward
[rank0]:     hidden_states = decoder_layer(
[rank0]:                     ^^^^^^^^^^^^^^
[rank0]:   File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/less/torchtitan/torchtitan/experiments/deepseek_v3/model.py", line 1123, in forward
[rank0]:     hidden_states = self.mlp(hidden_states)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/less/torchtitan/torchtitan/experiments/deepseek_v3/model.py", line 642, in forward
[rank0]:     y = self.moe_forward(hidden_states, topk_idx, topk_weight)
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/less/torchtitan/torchtitan/experiments/deepseek_v3/model.py", line 699, in moe_forward
[rank0]:     gathered_tokens = all_to_all_single_autograd(
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/distributed/_functional_collectives.py", line 527, in all_to_all_single_autograd
[rank0]:     tensor = torch.ops._c10d_functional_autograd.all_to_all_single(  # type: ignore[attr-defined]
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/less/.conda/envs/pycutlass/lib/python3.12/site-packages/torch/_ops.py", line 1208, in __call__
[rank0]:     return self._op(*args, **(kwargs or {}))
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: NotImplementedError: Could not run '_c10d_functional_autograd::all_to_all_single' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. '_c10d_functional_autograd::all_to_all_single' is only available for these backends: [Meta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradMAIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastMTIA, AutocastMAIA, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

[rank0]: Meta: registered at /pytorch/aten/src/ATen/core/MetaFallbackKernel.cpp:23 [backend fallback]
[rank0]: BackendSelect: fallthrough registered at /pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
[rank0]: Python: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:194 [backend fallback]
[rank0]: FuncTorchDynamicLayerBackMode: registered at /pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:479 [backend fallback]
[rank0]: Functionalize: registered at /pytorch/aten/src/ATen/FunctionalizeFallbackKernel.cpp:349 [backend fallback]
[rank0]: Named: registered at /pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
[rank0]: Conjugate: registered at /pytorch/aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
[rank0]: Negative: registered at /pytorch/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
[rank0]: ZeroTensor: registered at /pytorch/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
[rank0]: ADInplaceOrView: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:104 [backend fallback]
[rank0]: AutogradOther: registered at /pytorch/torch/csrc/distributed/c10d/Functional.cpp:502 [autograd kernel]
[rank0]: AutogradCPU: registered at /pytorch/torch/csrc/distributed/c10d/Functional.cpp:502 [autograd kernel]
[rank0]: AutogradCUDA: registered at /pytorch/torch/csrc/distributed/c10d/Functional.cpp:502 [autograd kernel]
[rank0]: AutogradHIP: registered at /pytorch/torch/csrc/distributed/c10d/Functional.cpp:502 [autograd kernel]
[rank0]: AutogradXLA: registered at /pytorch/torch/csrc/distributed/c10d/Functional.cpp:502 [autograd kernel]
[rank0]: AutogradMPS: registered at /pytorch/torch/csrc/distributed/c10d/Functional.cpp:502 [autograd kernel]
[rank0]: AutogradIPU: registered at /pytorch/torch/csrc/distributed/c10d/Functional.cpp:502 [autograd kernel]
[rank0]: AutogradXPU: registered at /pytorch/torch/csrc/distributed/c10d/Functional.cpp:502 [autograd kernel]
[rank0]: AutogradHPU: registered at /pytorch/torch/csrc/distributed/c10d/Functional.cpp:502 [autograd kernel]
[rank0]: AutogradVE: registered at /pytorch/torch/csrc/distributed/c10d/Functional.cpp:502 [autograd kernel]
[rank0]: AutogradLazy: registered at /pytorch/torch/csrc/distributed/c10d/Functional.cpp:502 [autograd kernel]
[rank0]: AutogradMTIA: registered at /pytorch/torch/csrc/distributed/c10d/Functional.cpp:502 [autograd kernel]
[rank0]: AutogradMAIA: registered at /pytorch/torch/csrc/distributed/c10d/Functional.cpp:502 [autograd kernel]
[rank0]: AutogradPrivateUse1: registered at /pytorch/torch/csrc/distributed/c10d/Functional.cpp:502 [autograd kernel]
[rank0]: AutogradPrivateUse2: registered at /pytorch/torch/csrc/distributed/c10d/Functional.cpp:502 [autograd kernel]
[rank0]: AutogradPrivateUse3: registered at /pytorch/torch/csrc/distributed/c10d/Functional.cpp:502 [autograd kernel]
[rank0]: AutogradMeta: registered at /pytorch/torch/csrc/distributed/c10d/Functional.cpp:502 [autograd kernel]
[rank0]: AutogradNestedTensor: registered at /pytorch/torch/csrc/distributed/c10d/Functional.cpp:502 [autograd kernel]
[rank0]: Tracer: registered at /pytorch/torch/csrc/autograd/TraceTypeManual.cpp:294 [backend fallback]
[rank0]: AutocastCPU: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:322 [backend fallback]
[rank0]: AutocastMTIA: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:466 [backend fallback]
[rank0]: AutocastMAIA: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:504 [backend fallback]
[rank0]: AutocastXPU: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:542 [backend fallback]
[rank0]: AutocastMPS: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:209 [backend fallback]
[rank0]: AutocastCUDA: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:165 [backend fallback]
[rank0]: FuncTorchBatched: registered at /pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:731 [backend fallback]
[rank0]: BatchedNestedTensor: registered at /pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:758 [backend fallback]
[rank0]: FuncTorchVmapMode: fallthrough registered at /pytorch/aten/src/ATen/functorch/VmapModeRegistrations.cpp:27 [backend fallback]
[rank0]: Batched: registered at /pytorch/aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
[rank0]: VmapMode: fallthrough registered at /pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
[rank0]: FuncTorchGradWrapper: registered at /pytorch/aten/src/ATen/functorch/TensorWrapper.cpp:208 [backend fallback]
[rank0]: PythonTLSSnapshot: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:202 [backend fallback]
[rank0]: FuncTorchDynamicLayerFrontMode: registered at /pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:475 [backend fallback]
[rank0]: PreDispatch: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:206 [backend fallback]
[rank0]: PythonDispatcher: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:198 [backend fallback]
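The table lists kernels for the Autograd* keys only, yet the error names the plain 'CUDA' key. The toy model below is a simplified, hypothetical sketch of dispatch-key selection (three keys only, not PyTorch's real dispatcher) showing how that can happen: if the autograd keys are excluded for a call, dispatch skips AutogradCUDA, passes the ADInplaceOrView fallthrough, and lands on CUDA, where this op has no kernel. torch.inference_mode() is one context that excludes autograd keys this way, while model.eval() and torch.no_grad() do not. Whether generate.py runs under inference mode is an assumption; the decorate_context frame in the trace only shows that generate is wrapped by a torch.utils._contextlib decorator such as torch.no_grad() or torch.inference_mode().

```python
# Hypothetical, simplified model of dispatch-key selection.
# Keys are listed from highest to lowest priority (tiny subset).
PRIORITY = ["AutogradCUDA", "ADInplaceOrView", "CUDA"]

# As in the table above: ADInplaceOrView is a fallthrough, and the op
# has an AutogradCUDA kernel but no CUDA kernel.
FALLTHROUGH = {"ADInplaceOrView"}
KERNELS = {"AutogradCUDA": lambda: "autograd kernel ran"}


def dispatch(op_name, tensor_keys, excluded=frozenset()):
    """Run the kernel for the highest-priority active key, or raise."""
    for key in PRIORITY:
        if key not in tensor_keys or key in excluded:
            continue  # key is not active for this call
        if key in FALLTHROUGH:
            continue  # fallthrough: move on to the next key
        if key in KERNELS:
            return KERNELS[key]()
        raise NotImplementedError(
            f"Could not run '{op_name}' with arguments from the '{key}' backend."
        )
    raise NotImplementedError(f"No active dispatch key for '{op_name}'.")


OP = "_c10d_functional_autograd::all_to_all_single"
CUDA_TENSOR_KEYS = {"AutogradCUDA", "ADInplaceOrView", "CUDA"}
AUTOGRAD_KEYS = {"AutogradCUDA", "ADInplaceOrView"}

# Normal call: the autograd kernel is found first.
print(dispatch(OP, CUDA_TENSOR_KEYS))

# Assumption: the call runs with autograd keys excluded, as under
# torch.inference_mode(); dispatch then reaches the missing CUDA kernel.
try:
    dispatch(OP, CUDA_TENSOR_KEYS, excluded=AUTOGRAD_KEYS)
except NotImplementedError as err:
    print(err)
```

The first call prints "autograd kernel ran"; the second prints a message of the same shape as the error in this report.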

Versions

PyTorch version: 2.8.0.dev20250515+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.5.0 20240719 (Red Hat 11.5.0-5)
Clang version: Could not collect
CMake version: version 3.26.5
Libc version: glibc-2.34

Python version: 3.12.9 | packaged by Anaconda, Inc. | (main, Feb 6 2025, 18:56:27) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.9.0-0_fbk7_0_g44ab11142beb-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.8.93
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA B200
GPU 1: NVIDIA B200
GPU 2: NVIDIA B200
GPU 3: NVIDIA B200
GPU 4: NVIDIA B200
GPU 5: NVIDIA B200
GPU 6: NVIDIA B200
GPU 7: NVIDIA B200

Nvidia driver version: 570.124.06
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.9.8.0
/usr/lib64/libcudnn_adv.so.9.8.0
/usr/lib64/libcudnn_cnn.so.9.8.0
/usr/lib64/libcudnn_engines_precompiled.so.9.8.0
/usr/lib64/libcudnn_engines_runtime_compiled.so.9.8.0
/usr/lib64/libcudnn_graph.so.9.8.0
/usr/lib64/libcudnn_heuristic.so.9.8.0
/usr/lib64/libcudnn_ops.so.9.8.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 384
On-line CPU(s) list: 0-383
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9654 96-Core Processor
CPU family: 25
Model: 17
Thread(s) per core: 2
Core(s) per socket: 96
Socket(s): 2
Stepping: 1
Frequency boost: enabled
CPU(s) scaling MHz: 74%
CPU max MHz: 3707.8120
CPU min MHz: 1500.0000
BogoMIPS: 4792.43
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor smca fsrm flush_l1d debug_swap
Virtualization: AMD-V
L1d cache: 6 MiB (192 instances)
L1i cache: 6 MiB (192 instances)
L2 cache: 192 MiB (192 instances)
L3 cache: 768 MiB (24 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-95,192-287
NUMA node1 CPU(s): 96-191,288-383
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable; IBPB: disabled; STIBP: disabled; PBRSB-eIBRS: Not affected; BHI: Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] nvidia-cublas-cu12==12.8.3.14
[pip3] nvidia-cuda-cupti-cu12==12.8.57
[pip3] nvidia-cuda-nvrtc-cu12==12.8.61
[pip3] nvidia-cuda-runtime-cu12==12.8.57
[pip3] nvidia-cudnn-cu12==9.8.0.87
[pip3] nvidia-cufft-cu12==11.3.3.41
[pip3] nvidia-curand-cu12==10.3.9.55
[pip3] nvidia-cusolver-cu12==11.7.2.55
[pip3] nvidia-cusparse-cu12==12.5.7.53
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.5
[pip3] nvidia-nvjitlink-cu12==12.8.61
[pip3] nvidia-nvtx-cu12==12.8.55
[pip3] pytorch-triton==3.3.0+git96316ce5
[pip3] torch==2.8.0.dev20250515+cu128
[pip3] torchao==0.12.0.dev20250516+cu128
[pip3] torchdata==0.11.0
[pip3] torchtitan==0.0.2
[conda] numpy 2.2.5 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.3.14 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.57 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.61 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.57 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.8.0.87 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.41 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.55 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.2.55 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.7.53 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.6.3 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.26.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.61 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.55 pypi_0 pypi
[conda] pytorch-triton 3.3.0+git96316ce5 pypi_0 pypi
[conda] torch 2.8.0.dev20250515+cu128 pypi_0 pypi
[conda] torchao 0.12.0.dev20250516+cu128 pypi_0 pypi
[conda] torchdata 0.11.0 pypi_0 pypi
[conda] torchtitan 0.0.2 pypi_0 pypi

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

Metadata

Assignees

No one assigned

    Labels

    oncall: distributed

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
