[DCP] failure case of save method · Issue #152310 · pytorch/pytorch · GitHub
[DCP] failure case of save method #152310
Closed

Description

@XuezheMax

🐛 Describe the bug

#147675 fixed this issue in dcp's gather_object. However, the same bug is still present in broadcast_object:

src=self.coordinator_rank,

I manually fixed this bug locally, but I still hit the following failure when saving an FSDP model. In the code sample below, I create two FSDP models across 8 GPUs: the first on ranks [0, 2, 4, 6] and the second on ranks [1, 3, 5, 7]. But when I use dcp.save to save these two models separately, each with its own process_group, the job hangs and then fails.
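As far as I can tell, the mismatch is the same one gather_object had: broadcast_object_list interprets src as a global rank, while self.coordinator_rank is a rank within the saving process group, so for a group like [1, 3, 5, 7] the broadcast points at global rank 0, which is not a member. A minimal, self-contained sketch of the translation that avoids this (the helper below is hypothetical, for illustration only, not DCP's actual code):

import torch.distributed as dist


def broadcast_object_in_group(obj, group, coordinator_group_rank: int = 0):
    # Hypothetical helper, for illustration only -- not DCP's actual code.
    # broadcast_object_list treats `src` as a global rank, so a coordinator
    # rank given relative to `group` has to be translated first; otherwise a
    # group such as [1, 3, 5, 7] would broadcast from global rank 0, which is
    # not a member of the group.
    object_list = [obj]
    global_src = dist.get_global_rank(group, coordinator_group_rank)
    dist.broadcast_object_list(object_list, src=global_src, group=group)
    return object_list[0]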

command line:

NGPU=8; torchrun --nproc_per_node=$NGPU test_dcp.py

output:

W0428 06:10:32.341000 1791284 site-packages/torch/distributed/run.py:766] *****************************************
W0428 06:10:32.341000 1791284 site-packages/torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0428 06:10:32.341000 1791284 site-packages/torch/distributed/run.py:766] *****************************************
Run launched with torchrun, local rank: 3
Run launched with torchrun, local rank: 1
Run launched with torchrun, local rank: 7
Run launched with torchrun, local rank: 2
Run launched with torchrun, local rank: 0
Run launched with torchrun, local rank: 4
Run launched with torchrun, local rank: 5
Run launched with torchrun, local rank: 6

Global rank: 0 -- data parallel rank: 0/4 -- tensor parallel rank: 0/2 -- context parallel rank: 0/1
Global rank: 2 -- data parallel rank: 1/4 -- tensor parallel rank: 0/2 -- context parallel rank: 0/1
Global rank: 7 -- data parallel rank: 3/4 -- tensor parallel rank: 1/2 -- context parallel rank: 0/1
Global rank: 6 -- data parallel rank: 3/4 -- tensor parallel rank: 0/2 -- context parallel rank: 0/1
Global rank: 3 -- data parallel rank: 1/4 -- tensor parallel rank: 1/2 -- context parallel rank: 0/1
Global rank: 4 -- data parallel rank: 2/4 -- tensor parallel rank: 0/2 -- context parallel rank: 0/1
Global rank: 5 -- data parallel rank: 2/4 -- tensor parallel rank: 1/2 -- context parallel rank: 0/1
Global rank: 1 -- data parallel rank: 0/4 -- tensor parallel rank: 1/2 -- context parallel rank: 0/1

Global rank: 0 -- tensor parallel rank: 0/2 -- hsdp group: [0, 2, 4, 6]
Global rank: 4 -- tensor parallel rank: 0/2 -- hsdp group: [0, 2, 4, 6]
Global rank: 2 -- tensor parallel rank: 0/2 -- hsdp group: [0, 2, 4, 6]
Global rank: 6 -- tensor parallel rank: 0/2 -- hsdp group: [0, 2, 4, 6]
Global rank: 3 -- tensor parallel rank: 1/2 -- hsdp group: [1, 3, 5, 7]
Global rank: 5 -- tensor parallel rank: 1/2 -- hsdp group: [1, 3, 5, 7]
Global rank: 7 -- tensor parallel rank: 1/2 -- hsdp group: [1, 3, 5, 7]
Global rank: 1 -- tensor parallel rank: 1/2 -- hsdp group: [1, 3, 5, 7]

Global rank: 0 -- Checkpoint done.
Global rank: 2 -- Checkpoint done.
Global rank: 4 -- Checkpoint done.
Global rank: 6 -- Checkpoint done.

[rank3]: Traceback (most recent call last):
[rank3]:   File "/lustrefs/users/xuezhe.ma/projects/gecko/tests/test_dcp.py", line 272, in <module>
[rank3]:     dcp.save(sharded_model_state, checkpoint_id=model_path, process_group=hsdp_group)
[rank3]:   File "/lustrefs/users/xuezhe.ma/miniconda3/envs/gecko2.7.0_cu12.8/lib/python3.10/site-packages/torch/distributed/checkpoint/logger.py", line 87, in wrapper
[rank3]:     result = func(*args, **kwargs)
[rank3]:   File "/lustrefs/users/xuezhe.ma/miniconda3/envs/gecko2.7.0_cu12.8/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 465, in inner_func
[rank3]:     return func(*args, **kwargs)
[rank3]:   File "/lustrefs/users/xuezhe.ma/miniconda3/envs/gecko2.7.0_cu12.8/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 176, in save
[rank3]:     return _save_state_dict(
[rank3]:   File "/lustrefs/users/xuezhe.ma/miniconda3/envs/gecko2.7.0_cu12.8/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 350, in _save_state_dict
[rank3]:     central_plan: SavePlan = distW.reduce_scatter("plan", local_step, global_step)
[rank3]:   File "/lustrefs/users/xuezhe.ma/miniconda3/envs/gecko2.7.0_cu12.8/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 196, in reduce_scatter
[rank3]:     all_data = self.gather_object(local_data)
[rank3]:   File "/lustrefs/users/xuezhe.ma/miniconda3/envs/gecko2.7.0_cu12.8/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 135, in gather_object
[rank3]:     dist.gather_object(
[rank3]:   File "/lustrefs/users/xuezhe.ma/miniconda3/envs/gecko2.7.0_cu12.8/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank3]:     return func(*args, **kwargs)
[rank3]:   File "/lustrefs/users/xuezhe.ma/miniconda3/envs/gecko2.7.0_cu12.8/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3153, in gather_object
[rank3]:     all_gather(object_size_list, local_size, group=group)
[rank3]:   File "/lustrefs/users/xuezhe.ma/miniconda3/envs/gecko2.7.0_cu12.8/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank3]:     return func(*args, **kwargs)
[rank3]:   File "/lustrefs/users/xuezhe.ma/miniconda3/envs/gecko2.7.0_cu12.8/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3706, in all_gather
[rank3]:     return handle_torch_function(
[rank3]:   File "/lustrefs/users/xuezhe.ma/miniconda3/envs/gecko2.7.0_cu12.8/lib/python3.10/site-packages/torch/overrides.py", line 1721, in handle_torch_function
[rank3]:     result = mode.__torch_function__(public_api, types, args, kwargs)
[rank3]:   File "/lustrefs/users/xuezhe.ma/miniconda3/envs/gecko2.7.0_cu12.8/lib/python3.10/site-packages/torch/utils/_device.py", line 104, in __torch_function__
[rank3]:     return func(*args, **kwargs)
[rank3]:   File "/lustrefs/users/xuezhe.ma/miniconda3/envs/gecko2.7.0_cu12.8/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank3]:     return func(*args, **kwargs)
[rank3]:   File "/lustrefs/users/xuezhe.ma/miniconda3/envs/gecko2.7.0_cu12.8/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3728, in all_gather
[rank3]:     work = group.allgather([tensor_list], [tensor])
[rank3]: torch.distributed.DistBackendError: NCCL error in: /pytorch/torch/csrc/distributed/c10d/NCCLUtils.cpp:77, remote process exited or there was a network error, NCCL version 2.26.2
[rank3]: ncclRemoteError: A call failed possibly due to a network error or a remote process exiting prematurely.
[rank3]: Last error:
[rank3]: socketPollConnect: connect returned Connection refused, exceeded error retry count (35)

code (test_dcp.py):
import sys
import math
import datetime
from typing import Dict, Tuple, Any
import logging
from pathlib import Path
import contextlib
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel
import torch.distributed.checkpoint as dcp
from torch.distributed.fsdp.wrap import enable_wrap, wrap
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
    StateDictOptions
)

logger = logging.getLogger()


@contextlib.contextmanager
def create_on_gpu():
    torch.set_default_device("cuda")
    try:
        yield
    finally:
        torch.set_default_device("cpu")


def initialize_logger() -> logging.Logger:
    # log everything
    logger = logging.getLogger()
    logger.setLevel(logging.NOTSET)

    # stdout: everything
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.NOTSET)

    # stderr: warnings / errors and above
    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setLevel(logging.WARNING)

    # set stream handlers
    logger.handlers.clear()
    assert len(logger.handlers) == 0, logger.handlers
    logger.handlers.append(stdout_handler)
    logger.handlers.append(stderr_handler)

    return logger


class Linear(nn.Module):

    def __init__(
        self,
        in_features: int,
        out_features: int,
        tensor_parallel_rank: int,
        tensor_parallel_size: int,
    ) -> None:
        super(Linear, self).__init__()

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        # Divide the weight matrix along the last dimension.
        assert out_features % tensor_parallel_size == 0
        self.output_size_per_partition = out_features // tensor_parallel_size
        self.weight = nn.Parameter(torch.Tensor(self.output_size_per_partition, self.in_features))

        # Initialize master weight
        master_weight = torch.empty(out_features, in_features, dtype=self.weight.dtype, requires_grad=False)
        nn.init.kaiming_normal_(master_weight, a=math.sqrt(5.0))

        # Split and copy
        weight_list = torch.split(master_weight, self.output_size_per_partition, dim=0)
        my_weight = weight_list[tensor_parallel_rank]

        with torch.no_grad():
            self.weight.copy_(my_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        return F.linear(x, self.weight)


def init_torch_distributed(timeout: int = 1800) -> Tuple[int, int]:
    """
    Handle single and multi-GPU / multi-node / SLURM jobs.
    Initialize the following variables:
        - global_rank
        - world_size
    """
    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    logger.info(f"Run launched with torchrun, local rank: {local_rank}")

    # set GPU device
    assert 0 <= local_rank < 8
    torch.cuda.set_device(local_rank)

    torch.distributed.init_process_group(
        init_method="env://",
        backend="nccl",
        timeout=datetime.timedelta(seconds=timeout),
    )

    assert global_rank == torch.distributed.get_rank()
    assert world_size == torch.distributed.get_world_size()

    # sanity check
    assert 0 <= local_rank <= global_rank < world_size

    return global_rank, world_size


def get_parallel_ranks(
    global_rank: int,
    tensor_parallel_size: int,
    context_parallel_size: int
) -> Tuple[int, int, int]:
    tensor_parallel_rank = global_rank % tensor_parallel_size
    global_rank = global_rank // tensor_parallel_size
    context_parallel_rank = global_rank % context_parallel_size
    data_parallel_rank = global_rank // context_parallel_size

    return data_parallel_rank, context_parallel_rank, tensor_parallel_rank


def get_device_mesh(
    data_parallel_size: int,
    context_parallel_size: int,
    tensor_parallel_size: int,
) -> DeviceMesh:
    world_size = torch.distributed.get_world_size()
    assert world_size == data_parallel_size * context_parallel_size * tensor_parallel_size
    if context_parallel_size == 1:
        if tensor_parallel_size == 1:
            mesh = torch.arange(world_size)
            device_mesh = DeviceMesh(
                device_type="cuda",
                mesh=mesh,
                mesh_dim_names=("dp",),
            )
        else:
            mesh = torch.arange(world_size).view(data_parallel_size, tensor_parallel_size)
            device_mesh = DeviceMesh(
                device_type="cuda",
                mesh=mesh,
                mesh_dim_names=("dp", "tp"),
            )["dp"]
    else:
        if tensor_parallel_size == 1:
            mesh = torch.arange(world_size).view(data_parallel_size, context_parallel_size)
            mesh = mesh.swapdims(0, 1)
            device_mesh = DeviceMesh(
                device_type="cuda",
                mesh=mesh,
                mesh_dim_names=("cp", "dp"),
            )
        else:
            mesh = torch.arange(world_size).view(data_parallel_size, context_parallel_size, tensor_parallel_size)
            mesh = mesh.swapdims(0, 1)
            device_mesh = DeviceMesh(
                device_type="cuda",
                mesh=mesh,
                mesh_dim_names=("cp", "dp", "tp"),
            )["cp", "dp"]

    return device_mesh


def build_model(
    in_features: int,
    out_features: int,
    data_parallel_size: int,
    tensor_parallel_rank: int,
    tensor_parallel_size: int,
    context_parallel_size: int
):
    if context_parallel_size == 1:
        sharding_strategy = ShardingStrategy.FULL_SHARD
    else:
        sharding_strategy = ShardingStrategy.HYBRID_SHARD

    device_mesh = get_device_mesh(data_parallel_size, context_parallel_size, tensor_parallel_size)
    mixed_precision = MixedPrecision(param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32)

    fsdp_cfg = {
        "sharding_strategy": sharding_strategy,
        "mixed_precision": mixed_precision,
        "sync_module_states": False,
        "use_orig_params": False,  # flatten parameters
        "device_mesh": device_mesh
    }
    with create_on_gpu():
        with enable_wrap(wrapper_cls=FullyShardedDataParallel, **fsdp_cfg):
            model = Linear(in_features, out_features, tensor_parallel_rank, tensor_parallel_size)
            model = wrap(model.cuda())
            model.train()

        return model


def get_model_state(model, full_state: bool = False) -> Dict[str, Any]:
    state_dict_options = StateDictOptions(
        full_state_dict=full_state,
        cpu_offload=True,
    )
    return get_model_state_dict(model, options=state_dict_options)


def set_model_state(model, model_state_dict, full_state: bool = False):
    state_dict_options = StateDictOptions(
        full_state_dict=full_state,
        cpu_offload=True,
    )
    set_model_state_dict(model, model_state_dict, options=state_dict_options)


if __name__ == "__main__":
    seed = 42
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    initialize_logger()

    global_rank, world_size = init_torch_distributed()

    tensor_parallel_size = 2
    context_parallel_size = 1
    data_parallel_size = world_size // (tensor_parallel_size * context_parallel_size)

    dp_rank, cp_rank, tp_rank = get_parallel_ranks(global_rank, tensor_parallel_size, context_parallel_size)

    logger.info(
        f"Global rank: {global_rank} -- "
        f"data parallel rank: {dp_rank}/{data_parallel_size} -- "
        f"tensor parallel rank: {tp_rank}/{tensor_parallel_size} -- "
        f"context parallel rank: {cp_rank}/{context_parallel_size}"
    )

    in_features = 1024
    out_features = 4096

    model = build_model(in_features, out_features, data_parallel_size, tp_rank, tensor_parallel_size, context_parallel_size)

    x = torch.ones(2, in_features).cuda()
    y = model(x).sum()
    print(y)

    checkpoint_dir = Path("saved_models/ckpt")
    model_path = checkpoint_dir / f"sharded_model.tp{tp_rank:02d}"
    sharded_model_state = get_model_state(model, full_state=False)
    groups = torch.LongTensor(range(world_size)).reshape(-1, tensor_parallel_size)[:, tp_rank].tolist()
    hsdp_group = torch.distributed.new_group(groups)
    logger.info(
        f"Global rank: {global_rank} -- "
        f"tensor parallel rank: {tp_rank}/{tensor_parallel_size} -- "
        f"hsdp group: {groups}"
    )
    dcp.save(sharded_model_state, checkpoint_id=model_path, process_group=hsdp_group)

    logger.info(f"Global rank: {global_rank} -- Checkpoint done.")

    torch.distributed.destroy_process_group()

Versions

2.7.0

cc @LucasLLC @pradeepfn

Metadata

Assignees

No one assigned

Labels

oncall: distributed checkpointing