
[DCP] set_model_state_dict errors on compiled module with non-persistent buffer #122792


Closed
awgu opened this issue Mar 27, 2024 · 2 comments
Labels
triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

awgu (Collaborator) commented Mar 27, 2024
"""
torchrun --standalone --nproc_per_node=2 repro_dcp_compile.py
"""
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import get_model_state_dict, set_model_state_dict

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(4, 4)
        self.lin2 = nn.Linear(4, 4)
        self.register_buffer("buf", torch.randn((4,)), persistent=False)
        self.weight = nn.Parameter(torch.randn((4, 4)))


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    gpu_id = int(os.environ["LOCAL_RANK"])
    device = f"cuda:{gpu_id}"
    torch.cuda.set_device(device)

    model = Model()
    model = torch.compile(model)

    sharded_sd = get_model_state_dict(model)
    set_model_state_dict(model, sharded_sd)
[rank0]: Traceback (most recent call last):
[rank0]:   File "/data/users/andgu/pytorch/repro_dcp_compile.py", line 36, in <module>
[rank0]:     set_model_state_dict(model, sharded_sd)
[rank0]:   File "/data/users/andgu/pytorch/torch/distributed/checkpoint/state_dict.py", line 853, in set_model_state_dict
[rank0]:     return _load_model_state_dict(model, model_state_dict, info)
[rank0]:   File "/data/users/andgu/pytorch/torch/distributed/checkpoint/state_dict.py", line 416, in _load_model_state_dict
[rank0]:     state_dict[fqn_with_prefix] = state_dict.pop(fqn)
[rank0]: KeyError: 'buf'

set_model_state_dict calls into _load_model_state_dict, which iterates over named_buffers(). For a compiled module, every FQN carries the _orig_mod. prefix, so fqn and fqn_with_prefix always mismatch, and _load_model_state_dict tries to reassign each entry from the FQN without the prefix to the one with it. However, this does not account for non-persistent buffers, which never exist in the state dict; hence the KeyError: 'buf' above.
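
Both ingredients are easy to verify without any distributed setup. Below is a minimal sketch (module and buffer names are illustrative) showing that torch.compile prefixes every FQN with _orig_mod. while the non-persistent buffer never makes it into the state dict:

import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)
        self.register_buffer("buf", torch.randn(4), persistent=False)

m = torch.compile(M())
# torch.compile wraps the module, so every FQN gains an "_orig_mod." prefix:
print([name for name, _ in m.named_buffers()])
# -> ['_orig_mod.buf']
# state_dict() contains only persistent tensors, so there is no key for "buf":
print(sorted(m.state_dict().keys()))
# -> ['_orig_mod.lin.bias', '_orig_mod.lin.weight']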

One solution could be to simply continue if the fqn is not in the state_dict, as sketched below.
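
In pseudocode, the guarded loop would look roughly like the following. This is a hypothetical paraphrase of the remapping loop implied by the traceback above, not the actual PyTorch source:

# Hypothetical sketch; variable names follow the traceback, the rest is assumed.
for fqn_with_prefix, _ in model.named_buffers():
    fqn = fqn_with_prefix.replace("_orig_mod.", "")
    if fqn == fqn_with_prefix:
        continue  # no compile prefix, nothing to remap
    if fqn not in state_dict:
        continue  # non-persistent buffer: it was never in the state dict
    state_dict[fqn_with_prefix] = state_dict.pop(fqn)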

cc @LucasLLC

@awgu awgu added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Mar 27, 2024
wz337 (Contributor) commented May 7, 2024

Reopening this issue, as the fix only correctly handles non-persistent buffers in eager mode; there are still issues for compiled modules.

@wz337 wz337 reopened this May 7, 2024
antoinebrl pushed a commit to antoinebrl/pytorch that referenced this issue May 27, 2024
[DSD] Fix to remove non_persistent buffer in distributed state dict (pytorch#125337)

Summary:
Fixes pytorch#122792

state_dict includes only persistent buffers, while named_buffers() would
include non_persistent buffers.

Pull Request resolved: pytorch#125337
Approved by: https://github.com/awgu
ghstack dependencies: pytorch#125333, pytorch#125501, pytorch#125334, pytorch#125335, pytorch#125336
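
The summary captures the root cause: state_dict() and named_buffers() disagree exactly on the non-persistent buffers. A loader can therefore recover the set of FQNs to skip by set difference; the helper below is an illustrative sketch, not the actual pytorch#125337 diff:

import torch.nn as nn

def non_persistent_buffer_fqns(model: nn.Module):
    # Keys reported by named_buffers() but omitted from state_dict() are,
    # by definition, the non-persistent buffers.
    persistent_keys = set(model.state_dict().keys())
    return {fqn for fqn, _ in model.named_buffers() if fqn not in persistent_keys}

(nn.Module also records this information per module in the private _non_persistent_buffers_set attribute, which is what torch internals consult when building a state dict.)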
huydhn pushed a commit that referenced this issue May 27, 2024
[DSD] Fix to remove non_persistent buffer in distributed state dict (#125337) (#127219)

* [DSD] Fix to remove non_persistent buffer in distributed state dict (#125337)

Summary:
Fixes #122792

state_dict includes only persistent buffers, while named_buffers() would
include non_persistent buffers.

Pull Request resolved: #125337
Approved by: https://github.com/awgu
ghstack dependencies: #125333, #125501, #125334, #125335, #125336

* lintrunner

* lint

---------

Co-authored-by: Chien-Chin Huang <chienchin@fb.com>
Co-authored-by: Andrey Talman <atalman@fb.com>
fegin (Contributor) commented May 31, 2024

@wz337 Do we still have the issue? I believe #125337 fixes the issue from the DSD perspective. If the compiler cannot handle the non-persistent buffers, then it is unrelated to DSD.

@fegin fegin closed this as completed Jun 5, 2024