
[DCP] Allow for rank-specific tensors with duplicate keys #146566

Open

cassanof opened this issue Feb 6, 2025 · 3 comments
Labels
oncall: distributed checkpointing - Oncall label should be attached to any issues related to distributed checkpointing.
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module.

Comments

cassanof (Contributor) commented Feb 6, 2025

🚀 The feature, motivation and pitch

My understanding of DCP is that it assumes either DTensor or fully replicated tensors in the state dict. I have a custom sharding implementation that doesn't use DTensor, so I needed to write a custom SavePlanner class that gathers each shard before saving.
The logic for loading is even uglier, as I need to modify the metadata object. For some other tensors it's even worse, because it's not clear how to gather them at all (e.g. torchao's TorchAOBaseTensor, used for AdamWFp8); I haven't found a workaround for this.
It would be great if there were an option to save a checkpoint in which some tensors are rank-specific and don't need to be gathered.
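For illustration, a minimal sketch of the gather-before-save workaround described above. The key name `my_param`, the shard shape, the checkpoint path, and the assumption of equal-sized row shards are all hypothetical, a process group is assumed to already be initialized (e.g. via torchrun), and `dcp.save(..., checkpoint_id=...)` assumes a recent PyTorch release:

```python
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

def gather_full_tensor(local_shard: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """All-gather equal-sized per-rank shards and concatenate them along `dim`."""
    world_size = dist.get_world_size()
    shards = [torch.empty_like(local_shard) for _ in range(world_size)]
    dist.all_gather(shards, local_shard.contiguous())
    return torch.cat(shards, dim=dim)

# `local_shard` stands in for a shard produced by the custom sharding scheme.
local_shard = torch.randn(1024, 1024, device="cuda")

# Every rank now holds the same replicated tensor, which DCP dedupes at save
# time, but the full tensor has to fit in memory on every rank.
state_dict = {"my_param": gather_full_tensor(local_shard)}
dcp.save(state_dict, checkpoint_id="checkpoint/step_1000")
```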

Alternatives

No response

Additional context

No response

cc @LucasLLC @pradeepfn @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

malfet added the oncall: distributed label on Feb 6, 2025
yifuwang (Collaborator) commented:

cc @fegin

ekr0 added the oncall: distributed checkpointing label and removed the oncall: distributed label on Feb 13, 2025
jbschlosser added the triaged label on Feb 14, 2025
saumishr (Contributor) commented Apr 8, 2025

@cassanof DCP makes no assumptions about the parallelization or replication. Once the state dict is provided, it's saved using SPMD with dedupe and loaded with re-sharding if needed. Therefore, whatever tensors are provided to the ranks will get saved accordingly. Both of the options should work.
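Independent of which option is chosen, here is a minimal sketch of the SPMD save-with-dedupe and load-with-re-sharding flow described above, using DTensor (which DCP understands natively). The mesh size, shapes, key name, and checkpoint path are illustrative, and the `torch.distributed.tensor` import assumes PyTorch 2.5 or newer (older releases expose it under `torch.distributed._tensor`):

```python
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

# Assumes torchrun with 8 processes and an initialized process group.
mesh = init_device_mesh("cuda", (8,))

# Shard a tensor across the mesh; each rank holds only its local shard.
full = torch.randn(4096, 4096)
dt = distribute_tensor(full, mesh, placements=[Shard(0)])

# Each rank contributes its own shard; duplicates across ranks are deduped at save.
dcp.save({"w": dt}, checkpoint_id="ckpt")

# On load, DCP re-shards into whatever placement the destination DTensor uses,
# even if it differs from the placement used at save time.
dest = distribute_tensor(torch.empty(4096, 4096), mesh, placements=[Shard(1)])
dcp.load({"w": dest}, checkpoint_id="ckpt")
```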

cassanof (Contributor, Author) commented Apr 9, 2025

Hey @saumishr, thanks for the pointer.

Option #1 wouldn't work for us because the fully-replicated state won't fit in memory.

For option #2, I previously built a custom save planner for this exact purpose: it all-gathers the shards before saving them. However, the main problem is loading: it's unclear to me how to do the reverse operation, i.e. create a load planner that loads only a view of the fully-replicated tensor.
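A minimal sketch of the naive version of that reverse operation, assuming equal row-sharding and the same hypothetical key and checkpoint path as above: it loads the fully replicated tensor and then keeps only a per-rank view, so it still materializes the full tensor on every rank first.

```python
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

rank, world_size = dist.get_rank(), dist.get_world_size()

# Allocate a buffer matching the saved (replicated) shape and load into it in-place.
full = torch.empty(1024, 1024)
dcp.load({"my_param": full}, checkpoint_id="checkpoint/step_1000")

# Keep only this rank's row-slice; the full tensor was still materialized on
# every rank, which is exactly the memory problem described above.
rows = full.size(0) // world_size
local_shard = full.narrow(0, rank * rows, rows).clone()
del full
```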

Due to these issues, we decided to roll our own checkpointing logic, but it would be great to come back to DCP if there is a workaround for this.

Projects
None yet
Development

No branches or pull requests

6 participants