How to use fsdp2 cpu_offload? · Issue #1168 · pytorch/torchtitan · GitHub
Open · KimmiShi opened this issue May 6, 2025 · 4 comments

KimmiShi commented May 6, 2025

I am currently using CPUOffloadPolicy in the following way:

        # Assumed imports: torch, plus fully_shard, MixedPrecisionPolicy,
        # CPUOffloadPolicy, OffloadPolicy from torch.distributed.fsdp, and a
        # get_module_class_from_name helper (e.g. transformers.trainer_pt_utils).
        transformer_cls_to_wrap = list()
        for layer_class in transformer_cls_names_to_wrap:
            transformer_cls = get_module_class_from_name(model_to_wrap, layer_class)
            if transformer_cls is not None:
                transformer_cls_to_wrap.append(transformer_cls)
        if len(transformer_cls_to_wrap) == 0:
            raise NotImplementedError("len(transformer_cls_to_wrap) == 0, please check the wrapping rules!")
        mp_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
        )
        fsdp_kwargs = {
            "reshard_after_forward": True,
            "mp_policy": mp_policy,
            "offload_policy": CPUOffloadPolicy() if self.args.adam_offload else OffloadPolicy(),
        }

        # shard each transformer block, then lm_head, then the root module
        for cls_to_wrap in transformer_cls_to_wrap:
            for module in model_to_wrap.modules():
                if isinstance(module, cls_to_wrap):
                    fully_shard(module, **fsdp_kwargs)
        for name, module in model_to_wrap.named_modules():
            if 'lm_head' in name:
                fully_shard(module, **fsdp_kwargs)

        fully_shard(model_to_wrap, **fsdp_kwargs)

        # cast model into fp32 to create optimizer with fp32 states
        # https://github.com/pytorch/torchtitan/issues/1133#issuecomment-2824429682
        model_to_wrap = model_to_wrap.to(torch.float32)

        if is_meta_initialized(model_to_wrap):
            model_to_wrap.to_empty(device='cuda')

        return model_to_wrap

The model is created from a Hugging Face pretrained model, but I got the following error when doing clip_grad_norm_:

grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.grad_clip)
[rank4]:   File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/utils/clip_grad.py", line 30, in _no_grad_wrapper
[rank4]:     return func(*args, **kwargs)
[rank4]:   File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/utils/clip_grad.py", line 105, in clip_grad_norm_
[rank4]:     clip_coef = max_norm / (total_norm + 1e-6)
[rank4]:   File "/root/miniconda3/lib/python3.10/site-packages/torch/_tensor.py", line 39, in wrapped
[rank4]:     return f(*args, **kwargs)
[rank4]:   File "/root/miniconda3/lib/python3.10/site-packages/torch/_tensor.py", line 1032, in __rdiv__
[rank4]:     return self.reciprocal() * other
[rank4]:   File "/root/miniconda3/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
[rank4]:     return disable_fn(*args, **kwargs)
[rank4]:   File "/root/miniconda3/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank4]:     return fn(*args, **kwargs)
[rank4]:   File "/root/miniconda3/lib/python3.10/site-packages/torch/distributed/tensor/_api.py", line 340, in __torch_dispatch__
[rank4]:     return DTensor._op_dispatcher.dispatch(
[rank4]:   File "/root/miniconda3/lib/python3.10/site-packages/torch/distributed/tensor/_dispatch.py", line 181, in dispatch
[rank4]:     self.redistribute_local_args(
[rank4]:   File "/root/miniconda3/lib/python3.10/site-packages/torch/distributed/tensor/_dispatch.py", line 317, in redistribute_local_args
[rank4]:     resharded_local_tensor = redistribute_local_tensor(
[rank4]:   File "/root/miniconda3/lib/python3.10/site-packages/torch/distributed/tensor/_redistribute.py", line 195, in redistribute_local_tensor
[rank4]:     new_local_tensor = partial_spec._reduce_value(
[rank4]:   File "/root/miniconda3/lib/python3.10/site-packages/torch/distributed/tensor/_ops/_math_ops.py", line 126, in _reduce_value
[rank4]:     reduced_tensor = super()._reduce_value(tensor, mesh, mesh_dim)
[rank4]:   File "/root/miniconda3/lib/python3.10/site-packages/torch/distributed/tensor/placement_types.py", line 599, in _reduce_value
[rank4]:     return funcol.all_reduce(
[rank4]:   File "/root/miniconda3/lib/python3.10/site-packages/torch/distributed/_functional_collectives.py", line 175, in all_reduce
[rank4]:     tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name)
[rank4]:   File "/root/miniconda3/lib/python3.10/site-packages/torch/_ops.py", line 1116, in __call__
[rank4]:     return self._op(*args, **(kwargs or {}))
[rank4]: RuntimeError: No backend type associated with device type cpu

Is there anything wrong with my model init device?

tianyu-l (Contributor) commented May 6, 2025

cc @mori360 @weifengpy

fegin (Contributor) commented May 6, 2025

The gradients are placed on CPU when doing the collectives. @weifengpy What's the common practice when users need to manipulate gradients while CPU offload is on?
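
For illustration, here is a minimal sketch of what that looks like, assuming the default process group is already initialized (with a Gloo CPU backend, per the next comment) and using a made-up nn.Linear module; the module and sizes are illustrative, not from this issue:

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard

    # Assumes torch.distributed.init_process_group(...) has already run and a
    # CUDA device has been selected for this rank.
    model = nn.Linear(512, 512)  # illustrative module
    fully_shard(model, offload_policy=CPUOffloadPolicy())

    model(torch.randn(4, 512, device="cuda")).sum().backward()

    for name, p in model.named_parameters():
        # With CPUOffloadPolicy the sharded parameters and their gradients are
        # DTensors whose local shards live on CPU, so the all-reduce issued by
        # clip_grad_norm_ over the total norm runs on the CPU backend.
        print(name, p.device, p.grad.device)  # expected: cpu cpu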

weifengpy (Contributor) commented May 6, 2025

We need to init gloo at the beginning:

torch.distributed.init_process_group(backend='cpu:gloo,cuda:nccl')

torchtitan does the same here:

torch.distributed.init_process_group(
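
Putting it together, a minimal end-to-end sketch, assuming a single-node torchrun launch, a recent PyTorch where fully_shard and CPUOffloadPolicy are exported from torch.distributed.fsdp, and a toy model (names and sizes below are illustrative, not the code from this issue):

    import os
    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard

    # Register both backends up front: NCCL for the GPU collectives FSDP2 uses,
    # Gloo for collectives over CPU-resident gradients (e.g. inside
    # clip_grad_norm_ when CPU offload is enabled).
    torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # LOCAL_RANK is set by torchrun

    model = nn.Linear(1024, 1024)  # illustrative model
    fully_shard(model, offload_policy=CPUOffloadPolicy())

    model(torch.randn(8, 1024, device="cuda")).sum().backward()

    # The total-norm all-reduce now runs on Gloo instead of raising
    # "No backend type associated with device type cpu".
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    torch.distributed.destroy_process_group()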

KimmiShi (Author) commented May 7, 2025

@weifengpy Thanks a lot! So when CPU offload is enabled, all the gradients are also placed on CPU? The optimizer.step is also done on CPU?
