Description
🐛 Describe the bug
I'm seeing an issue with FSDP2 (i.e. fully_shard) where calling model.forward() twice in a row, without calling loss.backward() in between, throws the mixed torch.Tensor and DTensor error I've pasted below.
I'm not able to call loss.backward() in between, since I need the outputs of both forward passes before I can compute the loss. I also can't batch the inputs together because the model runs out of memory (OOM).
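To illustrate why the two forward passes cannot be separated by a backward, here is a minimal sketch assuming a standard DPO objective with precomputed reference log-probs. The helper name `dpo_loss`, its arguments, and `beta=0.1` are hypothetical and not taken from dpo.py. The loss couples the chosen and rejected responses, so both policy outputs must exist before a single backward can run.

```python
import torch
import torch.nn.functional as F


def dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    """Standard DPO loss (hypothetical helper, not the code in dpo.py)."""
    # Log-ratio of policy vs. reference for each response.
    chosen_logratios = policy_chosen_logps - ref_chosen_logps
    rejected_logratios = policy_rejected_logps - ref_rejected_logps
    # The loss depends on both responses, so both policy forward passes
    # must finish before it can be computed and backpropagated.
    logits = beta * (chosen_logratios - rejected_logratios)
    return -F.logsigmoid(logits).mean()


# Placeholder log-probs; in training, the two policy tensors would come
# from two separate (checkpointed, FSDP2-sharded) forward passes.
policy_chosen = torch.tensor([-1.0, -2.0], requires_grad=True)
policy_rejected = torch.tensor([-3.0, -2.5], requires_grad=True)
ref_chosen = torch.tensor([-1.5, -2.0])
ref_rejected = torch.tensor([-2.5, -2.5])
loss = dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected)
loss.backward()  # one backward through both forward outputs
```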
I'm also using activation checkpointing, which could be related to the issue. My guess is that the parameters are not all-gathered correctly when one of the forward passes is recomputed during backward.
Is there a fix or workaround we could use for this case? Or does FSDP2 require a backward before the next forward?
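The unpack_hook and recompute_fn frames in the traceback come from non-reentrant activation checkpointing, which re-runs each checkpointed forward during backward. The sketch below uses plain torch.utils.checkpoint without FSDP2 (so it does not reproduce the bug); the toy `CountingBlock` module is hypothetical. It only shows the call pattern from this report: two checkpointed forwards, one loss, one backward, after which each forward has been recomputed once.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CountingBlock(nn.Module):
    """Toy block that records how many times its forward runs (hypothetical)."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.calls = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.calls += 1
        return torch.relu(self.linear(x))


def run_two_forwards_one_backward(dim: int = 8) -> int:
    torch.manual_seed(0)
    block = CountingBlock(dim)
    chosen = torch.randn(4, dim)
    rejected = torch.randn(4, dim)
    # Two checkpointed forward passes with no backward in between.
    out_chosen = checkpoint(block, chosen, use_reentrant=False)
    out_rejected = checkpoint(block, rejected, use_reentrant=False)
    # A single loss that depends on both outputs.
    loss = (out_chosen.sum() - out_rejected.sum()) ** 2
    # Backward unpacks saved activations, which triggers one recompute
    # of each checkpointed region.
    loss.backward()
    return block.calls


print(run_two_forwards_one_backward())  # 2 original forwards + 2 recomputes
```

Under FSDP2, each of those recomputes has to run with the module's parameters unsharded again; the error suggests that, for at least one of them, the parameters were still the sharded DTensors.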
[rank4]: Traceback (most recent call last):
[rank4]: File "/traindata/eugen/torchtune/deepseek_v3/dpo.py", line 417, in <module>
[rank4]: run_dpo(config)
[rank4]: File "/traindata/eugen/torchtune/deepseek_v3/dpo.py", line 269, in run_dpo
[rank4]: loss.backward()
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/_tensor.py", line 648, in backward
[rank4]: torch.autograd.backward(
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/autograd/__init__.py", line 354, in backward
[rank4]: _engine_run_backward(
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
[rank4]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 1123, in unpack_hook
[rank4]: frame.recompute_fn(*args)
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 1513, in recompute_fn
[rank4]: fn(*args, **kwargs)
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1755, in _wrapped_call_impl
[rank4]: return self._call_impl(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1861, in _call_impl
[rank4]: return inner()
[rank4]: ^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1809, in inner
[rank4]: result = forward_call(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/deepseek_v3/components/deepseek_model.py", line 788, in forward
[rank4]: hidden_states = self.mlp(hidden_states)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1755, in _wrapped_call_impl
[rank4]: return self._call_impl(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1766, in _call_impl
[rank4]: return forward_call(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/deepseek_v3/components/deepseek_model.py", line 477, in forward
[rank4]: y = self.moe_forward(hidden_states, topk_idx, topk_weight).view(*orig_shape)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/deepseek_v3/components/deepseek_model.py", line 539, in moe_forward
[rank4]: processed_tokens[gatherd_idxs == i] = expert(
[rank4]: ^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1755, in _wrapped_call_impl
[rank4]: return self._call_impl(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1861, in _call_impl
[rank4]: return inner()
[rank4]: ^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1809, in inner
[rank4]: result = forward_call(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/deepseek_v3/components/deepseek_model.py", line 341, in forward
[rank4]: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
[rank4]: ^^^^^^^^^^^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1755, in _wrapped_call_impl
[rank4]: return self._call_impl(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1766, in _call_impl
[rank4]: return forward_call(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 125, in forward
[rank4]: return F.linear(input, self.weight, self.bias)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/_compile.py", line 51, in inner
[rank4]: return disable_fn(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 857, in _fn
[rank4]: return fn(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_api.py", line 350, in __torch_dispatch__
[rank4]: return DTensor._op_dispatcher.dispatch(
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py", line 157, in dispatch
[rank4]: op_info = self.unwrap_to_op_info(op_call, args, kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py", line 356, in unwrap_to_op_info
[rank4]: self._try_replicate_spec_for_scalar_tensor(
[rank4]: File "/traindata/eugen/torchtune/.venv/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py", line 458, in _try_replicate_spec_for_scalar_tensor
[rank4]: raise RuntimeError(
[rank4]: RuntimeError: aten.mm.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
Versions
PyTorch version: 2.8.0.dev20250506+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 4.0.0
Libc version: glibc-2.35
Python version: 3.12.8 | packaged by Anaconda, Inc. | (main, Dec 11 2024, 16:31:09) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-1081-aws-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.8.61
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H200
GPU 1: NVIDIA H200
GPU 2: NVIDIA H200
GPU 3: NVIDIA H200
GPU 4: NVIDIA H200
GPU 5: NVIDIA H200
GPU 6: NVIDIA H200
GPU 7: NVIDIA H200
Nvidia driver version: 550.163.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 96
On-line CPU(s) list: 0-95
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8488C
CPU family: 6
Model: 143
Thread(s) per core: 1
Core(s) per socket: 48
Socket(s): 2
Stepping: 8
BogoMIPS: 4800.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd ida arat avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid cldemote movdiri movdir64b md_clear serialize amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 4.5 MiB (96 instances)
L1i cache: 3 MiB (96 instances)
L2 cache: 192 MiB (96 instances)
L3 cache: 210 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-47
NUMA node1 CPU(s): 48-95
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] flashinfer-python==0.2.5+cu124torch2.6
[pip3] numpy==2.2.5
[pip3] nvidia-cublas-cu12==12.8.3.14
[pip3] nvidia-cuda-cupti-cu12==12.8.57
[pip3] nvidia-cuda-nvrtc-cu12==12.8.61
[pip3] nvidia-cuda-runtime-cu12==12.8.57
[pip3] nvidia-cudnn-cu12==9.8.0.87
[pip3] nvidia-cufft-cu12==11.3.3.41
[pip3] nvidia-curand-cu12==10.3.9.55
[pip3] nvidia-cusolver-cu12==11.7.2.55
[pip3] nvidia-cusparse-cu12==12.5.7.53
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.8.61
[pip3] nvidia-nvtx-cu12==12.8.55
[pip3] pytorch-triton==3.3.0+git96316ce5
[pip3] torch==2.8.0.dev20250506+cu128
[pip3] torchao==0.10.0
[pip3] torchaudio==2.6.0
[pip3] torchtune==0.6.0
[pip3] torchvision==0.21.0
[pip3] triton==3.2.0
[conda] No relevant packages
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k