Mismatch of mixed precision `cast_fn` in FSDP and FSDP2 #153077
Closed
@markovka17

Description


🐛 Describe the bug

FSDP2 does not work with dataclass inputs. More specifically, FSDP2's pre-forward hook does not cast the tensors held inside a dataclass. FSDP uses `_apply_to_tensors`, which recurses into dataclass-like objects, whereas FSDP2 applies a simple `_cast_fp_tensor` over the forward arguments.
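
The divergence can be seen with the pytree utility alone. Below is a minimal illustration, assuming FSDP2 walks the forward arguments with `torch.utils._pytree.tree_map` (as the 2.6 sources suggest); `cast_bf16` is a stand-in for the internal cast function, not the actual FSDP2 code. An unregistered dataclass is treated as a single opaque leaf, so the tensor inside is never visited:

import dataclasses

import torch
from torch.utils._pytree import tree_map


@dataclasses.dataclass
class Input:
    x: torch.Tensor


def cast_bf16(obj):
    # Stand-in for FSDP2's cast: convert only floating-point tensors.
    if isinstance(obj, torch.Tensor) and obj.is_floating_point():
        return obj.to(torch.bfloat16)
    return obj


out = tree_map(cast_bf16, Input(torch.randn(2, 10)))
# Input is not registered with pytree, so the whole dataclass is a leaf:
# cast_bf16 receives the Input object, not the tensor, and returns it as-is.
print(out.x.dtype)  # torch.float32

The full repro against both FSDP and FSDP2: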

import dataclasses

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import fully_shard, MixedPrecision, MixedPrecisionPolicy
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP


@dataclasses.dataclass
class Input:
    x: torch.Tensor


def main():
    class Model(nn.Module):

        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            self._layer = nn.Linear(10, 10)

        def forward(self, input: Input):
            # RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16
            return self._layer(input.x)

    # Example with FSDP2 (fails: the tensor inside the dataclass is never cast)
    # model = Model().cuda()
    # input = Input(torch.randn(2, 10).cuda())

    # fully_shard(model, mp_policy=MixedPrecisionPolicy(
    #     param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16,
    #     output_dtype=torch.bfloat16, cast_forward_inputs=True))
    # _ = model(input)

    # Example with FSDP (works)
    model = Model().cuda()
    input = Input(torch.randn(2, 10).cuda())

    model = FSDP(model, mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16))
    _ = model(input)


if __name__ == "__main__":
    # dist.init_process_group(...)
    main()
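
A possible workaround (my assumption, not an official recommendation): register the dataclass with pytree so that FSDP2's tree_map-based cast can descend into it. Note that `torch.utils._pytree` is a private module and may change between releases:

import dataclasses

import torch
from torch.utils._pytree import register_pytree_node


@dataclasses.dataclass
class Input:
    x: torch.Tensor


# Hypothetical fix: teach pytree how to flatten/unflatten Input so the
# tensor field becomes a visible leaf for FSDP2's input cast.
register_pytree_node(
    Input,
    lambda inp: ([inp.x], None),              # flatten -> (children, context)
    lambda children, _ctx: Input(*children),  # unflatten
)

Alternatively, one can pass cast_forward_inputs=False to MixedPrecisionPolicy and cast the tensor manually before calling the model, e.g. input = Input(x.to(torch.bfloat16)).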

Versions

PyTorch version: 2.6.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: version 3.31.1
Libc version: glibc-2.39

Python version: 3.12.3 (main, Jan 17 2025, 18:03:48) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.4.210-39.1.pagevecsize-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.6.85
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB

Nvidia driver version: 560.35.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.6.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 230
On-line CPU(s) list: 0-229
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7702 64-Core Processor
CPU family: 23
Model: 49
Thread(s) per core: 1
Core(s) per socket: 1
Socket(s): 230
Stepping: 0
BogoMIPS: 4000.52
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr wbnoinvd arat npt nrip_save umip rdpid arch_capabilities
Virtualization: AMD-V
L1d cache: 14.4 MiB (230 instances)
L1i cache: 14.4 MiB (230 instances)
L2 cache: 115 MiB (230 instances)
L3 cache: 3.6 GiB (230 instances)
NUMA node(s): 4
NUMA node0 CPU(s): 0-56
NUMA node1 CPU(s): 57-113
NUMA node2 CPU(s): 114-170
NUMA node3 CPU(s): 171-229
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] flake8==5.0.4
[pip3] lovely-numpy==0.2.13
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cudnn-frontend==1.8.0
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] nvtx==0.2.5
[pip3] open-clip-torch==2.24.0
[pip3] optree==0.13.1
[pip3] pynvjitlink==0.3.0
[pip3] pytorch-lightning==2.0.2
[pip3] pytorch-triton==3.0.0+72734f086
[pip3] pytorch-warmup==0.1.1
[pip3] torch==2.6.0+cu126
[pip3] torch-fidelity==0.3.0
[pip3] torch_tensorrt==2.6.0a0
[pip3] torchao==0.9.0
[pip3] torchmetrics==1.0.3
[pip3] torchprofile==0.0.4
[pip3] torchsde==0.2.6
[pip3] torchvision==0.21.0+cu126
[pip3] triton==3.2.0
[conda] Could not collect

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @zhaojuanmao @mrshenli @rohan-varma @chauhang @mori360
