[FSDP] FSDP with CPU offload consumes 1.65X more GPU memory when training models with most of the params frozen #91165
Comments
Thanks for providing so many details! This helps us out a lot. We may need more time to investigate why CPU offload consumes so much more GPU memory. For now, have you guys tried setting ...? IIUC, CPU offload was required for the 1T parameter training run (cc: @rohan-varma @zhaojuanmao), so the memory savings were proven there. Conceptually, I am not seeing how frozen parameters/layers affect the GPU memory usage aside from the optimizer state, as you mentioned. In that case, if you think it is helpful, could you verify what optimizer states are being created when using FSDP? Are you seeing that optimizer states are being created even for the frozen parameters?
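For reference, one way to check this is to inspect the optimizer's state after a few steps. The helper below is an illustrative sketch (not code from this thread), assuming a standard torch.optim optimizer such as AdamW:

import torch

def summarize_optimizer_state(model, optimizer):
    # Map each parameter tensor back to its name for readable output.
    name_by_param = {p: n for n, p in model.named_parameters()}
    for group in optimizer.param_groups:
        for p in group["params"]:
            state = optimizer.state.get(p, {})
            state_numel = sum(t.numel() for t in state.values() if torch.is_tensor(t))
            print(f"{name_by_param.get(p, '<unknown>')}: "
                  f"requires_grad={p.requires_grad}, state numel={state_numel}")
    # Frozen parameters should show requires_grad=False and state numel=0.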
Hello @awgu, thanks for your quick reply. Apart from keeping optimizer states only for the tiny portion of trainable params, offloading to CPU should further reduce GPU memory usage. In the case of CPU offloading, the expected behaviour would be to move the model shards to the GPU one at a time (one T5 block at a time, as they belong to separate FSDP units), do the forward-pass computation, and then move them back from GPU to CPU (and similarly for the backward pass). Hence the max parameter memory consumption on the GPU should be that of a single T5 block only. Let me know if there are gaps in my understanding.
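To make that expectation concrete, here is rough pseudocode of the per-unit flow being described (an illustration of the expected behaviour, not of FSDP's actual internals; each block below is a plain nn.Module standing in for one FSDP unit):

def forward_with_cpu_offload(t5_blocks, hidden_states):
    for block in t5_blocks:
        block.to("cuda")                      # bring this block's params onto the GPU
        hidden_states = block(hidden_states)  # compute the block's forward pass
        block.to("cpu")                       # offload the params back to CPU
    # Peak GPU memory for parameters should be roughly one block's worth.
    return hidden_states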
@pacman100 Do you have a recommendation for how to try to reproduce this?
@awgu From the source code, it seems that another problem with FSDP is that the frozen parameters are not deallocated until the end of the backward pass. If most parameters are frozen, as in this case, then the GPU would hold nearly a whole copy of the model parameters.
This is unfortunately a known problem :( We would need to add special hooks to reshard the frozen parameters after they are used for gradient computation. We have not had any bandwidth to look into this though.
Is it possible to use a custom wrapping policy to avoid this?
@hscspring I do not think so. The issue is more fundamental than simply choosing the wrapping policy: it arises when a parameter belongs to a module whose output activations require gradients but the parameter itself does not require gradients. In that case, FSDP's pre-backward hook runs but its post-backward hook does not.
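A minimal standalone example of that situation (my own illustration, not code from FSDP):

import torch
import torch.nn as nn

# A module whose parameters are frozen but whose output still requires grad.
frozen = nn.Linear(4, 4)
for p in frozen.parameters():
    p.requires_grad_(False)

x = torch.randn(2, 4, requires_grad=True)  # e.g. output of an upstream trainable layer
out = frozen(x)
print(out.requires_grad)   # True  -> backward traverses this module,
                           #          so FSDP's pre-backward hook would run
out.sum().backward()
print(frozen.weight.grad)  # None  -> no gradient is accumulated for the frozen weight,
                           #          so the per-parameter post-backward hook never fires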
@awgu thanks for the quick reply. You are right, the wrapping policy only affects how the model is sharded.
@awgu any idea how to add such a hook? E.g., via torch.autograd.Function.
@pengyanghua I am testing #101982 out. I am seeing memory savings in my experiments, but @HamidShojanazeri is not at the moment. We are investigating. |
@awgu, just to update the thread as we discussed offline: I can observe the memory savings as well.
Since #101982 landed, I would recommend people try things out again (e.g. with a nightly). There should be memory savings now. |
Thank you @awgu and @HamidShojanazeri. Will try this out this week and report back. This should work with CPU offloading too, right?
@pacman100 Yes, it should work with CPU offloading. (It turns out that CPU offloading is connected to "freeing the parameters", so once we fixed "freeing the parameters", the parameters should be offloaded to CPU as well if enabled.)
Tried the latest nightly '2.1.0.dev20230620'. Using 2 A100 GPUs, I am able to run:
As I stated earlier, the expected behaviour is:
Steps to reproduce:
kwargs = {
"sharding_strategy": fsdp_plugin.sharding_strategy,
"cpu_offload": fsdp_plugin.cpu_offload,
"auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
"mixed_precision": fsdp_plugin.mixed_precision_policy,
"sync_module_states": fsdp_plugin.sync_module_states,
"backward_prefetch": fsdp_plugin.backward_prefetch,
"forward_prefetch": fsdp_plugin.forward_prefetch,
"use_orig_params": fsdp_plugin.use_orig_params,
"param_init_fn": fsdp_plugin.param_init_fn,
"ignored_modules": fsdp_plugin.ignored_modules,
- "ignored_parameters": fsdp_plugin.ignored_parameters,
"limit_all_gathers": fsdp_plugin.limit_all_gathers,
"device_id": self.device,
}
model = FSDP(model, **kwargs)
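To compare peak GPU memory between these configurations, a small helper like the following can be called after each training step (an illustrative sketch, not part of the original reproduction code):

import torch
import torch.distributed as dist

def report_peak_memory(tag):
    # Report the peak allocated GPU memory on this rank, then reset the counter.
    torch.cuda.synchronize()
    peak_gib = torch.cuda.max_memory_allocated() / 2**30
    print(f"[rank {dist.get_rank()}] {tag}: peak allocated {peak_gib:.2f} GiB")
    torch.cuda.reset_peak_memory_stats()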
FSDP config:
Code:
Output logs:
@pacman100 Can you please share the manual wrap code for LoRA? I encountered the same problem. Many thanks!
@pacman100 Is this issue solved in recent PyTorch versions?
Still having OOM issues with CPU offload on a 7B model in mixed precision.
Hello, I came across this important discussion. With the latest PyTorch nightlies (for example in the LLaMA-recipes implementation), running PEFT with FSDP indeed saves memory. For example, when using a mini-batch size of 4 to train a 13B model on 4 GPUs (with pure BF16), VRAM per GPU is 25 GB (PEFT) vs 55 GB (full model). Could we confirm this issue has been resolved in the latest PyTorch nightlies? Many thanks. Also, a note to future readers: there is some excellent discussion on this led by @awgu. Looking at the LLaMA-recipes implementation, it seems they largely followed the recommendations from that discussion; they use the customized lambda_policy_fn only when using PEFT.
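For readers looking for that policy, a sketch along the lines of the LLaMA-recipes approach looks roughly like this (the transformer block class and the exact checks are model-dependent assumptions on my part, not the exact upstream code):

import functools
from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers.models.t5.modeling_t5 import T5Block  # pick the block class for your model

def lambda_policy_fn(module):
    # Wrap leaf modules whose weight is trainable (e.g. LoRA layers) in their own
    # FSDP units, so no unit mixes frozen and trainable parameters.
    return (
        len(list(module.named_children())) == 0
        and getattr(module, "weight", None) is not None
        and module.weight.requires_grad
    )

auto_wrap_policy = functools.partial(
    _or_policy,
    policies=[
        functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn),
        functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={T5Block}),
    ],
)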
@awgu There are a lot of issues open regarding FSDP + PEFT. I've been looking at many of these discussion threads, and unfortunately, even with CPU offloading and the wrapping policy suggested here, I'm still facing persistent OOM issues. This was another issue we'd like to hear updates on. Is it possible to consolidate everything regarding this topic (current workarounds, PyTorch nightlies, current problems, WIP if any) either on this thread or in a new issue on this repo? That would really help a lot of us currently struggling to evade OOMs with FSDP + PEFT.
🐛 Describe the bug
Context: We have more and more situations where a large part of the model being trained is frozen. As these are very large LLMs, we want to leverage FSDP with CPU offloading to fit such large-model training, with only a tiny fraction of trainable params, on consumer GPUs. To this end, below is an example of fine-tuning bigscience/mt0-small using the LoRA parameter-efficient fine-tuning method with/without FSDP.
To fine-tune with FSDP: because wrapping the model naively fails with AssertionError: expects all parameters to have same requires_grad, we created a custom auto_wrap_policy such that the layers with trainable params are in separate FSDP units from those which are frozen. The resulting model along with the FSDP options is given below. We are using Accelerate's FSDP integration; the trainable params have the lora_ prefix.
Plain PyTorch:
FSDP Full Shard with CPU offloading:
When trying to use FSDP with CPU offloading with the bigscience/mt0-xxl model (13B params) on an A100 80GB GPU, it results in a GPU OOM error, whereas plain PyTorch consumes 56GB of GPU memory.
Expected behaviour: Efficiently deal with frozen weights during training such that large models can be offloaded to CPU / sharded across GPUs properly, with optimizer state stored only for the trainable parameters. For example, using plain PyTorch, the mt0-xxl (13B params) model takes up 56GB on the GPU; it would be really helpful if one could do CPU offloading such that training could work on a 16GB or 24GB GPU using FSDP with CPU offloading.
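For context, the LoRA setup that produces the lora_-prefixed trainable params looks roughly like this (hyperparameters here are illustrative assumptions, not the exact values used in the report):

from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, get_peft_model, TaskType

# Only the injected lora_* parameters are trainable; the base mt0 weights stay frozen.
model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/mt0-small")
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q", "v"],  # attention projections in the T5/mt0 architecture
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # tiny fraction of the total params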
Versions
Collecting environment information...
PyTorch version: 1.14.0.dev20221117+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.25.0
Libc version: glibc-2.31
Python version: 3.10.4 (main, Mar 31 2022, 08:41:55) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-125-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA DGX Display
GPU 4: NVIDIA A100-SXM4-80GB
Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.7.1
[pip3] numpy==1.23.4
[pip3] torch==1.14.0.dev20221117+cu117
[pip3] torchaudio==0.14.0.dev20221117
[pip3] torchtriton==2.0.0+0d7e753227
[pip3] torchvision==0.15.0.dev20221117
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py310h7f8727e_0
[conda] mkl_fft 1.3.1 py310hd6ae3a3_0
[conda] mkl_random 1.2.2 py310h00e6091_0
[conda] msgpack-numpy 0.4.7.1 pypi_0 pypi
[conda] numpy 1.23.4 pypi_0 pypi
[conda] pytorch-cuda 11.7 h67b0de4_0 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch 1.14.0.dev20221117+cu117 pypi_0 pypi
[conda] torchaudio 0.14.0.dev20221117 py310_cu117 pytorch-nightly
[conda] torchtriton 2.0.0+0d7e753227 pypi_0 pypi
[conda] torchvision 0.15.0.dev20221117 py310_cu117 pytorch-nightly
cc @ezyang @gchanan @zou3519 @kadeng @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin