🐛 Describe the bug
Context: We increasingly have situations where a large part of the model being trained is frozen. As these are very large LLMs, we want to leverage FSDP with CPU offloading so that training such large models, with only a tiny fraction of trainable params, fits on consumer GPUs. To this end, below is an example of fine-tuning bigscience/mt0-small using the LoRA parameter-efficient fine-tuning method, with and without FSDP.
To fine-tune with FSDP:
- Following "How do I freeze weights when using FSDP?" (huggingface/accelerate#807), and to avoid AssertionError: expects all parameters to have same requires_grad, we created a custom auto_wrap_policy so that the layers with trainable params end up in separate FSDP units from the frozen ones (a sketch of how such a policy can be built is included at the end of this section). The resulting model, along with the FSDP options, is given below. We are using Accelerate's FSDP integration; the trainable params have the lora_ prefix:
FullyShardedDataParallelPlugin(sharding_strategy=<ShardingStrategy.FULL_SHARD: 1>, backward_prefetch=<BackwardPrefetch.BACKWARD_PRE: 1>, mixed_precision_policy=None, auto_wrap_policy=functools.partial(<function _or_policy at 0x7f022c1a8430>, policies=[functools.partial(<function lambda_auto_wrap_policy at 0x7f022c1a8160>, lambda_fn=<function fsdp_auto_wrap_policy.<locals>.lambda_policy_fn at 0x7f01ec31e3b0>), functools.partial(<function transformer_auto_wrap_policy at 0x7f022c1a8310>, transformer_layer_cls=(<class 'pet.tuners.prefix_tuning.PrefixEncoder'>, <class 'pet.tuners.p_tuning.PromptEncoder'>, <class 'pet.tuners.prompt_tuning.PromptEmbedding'>, <class 'transformers.models.t5.modeling_t5.T5Block'>))]), cpu_offload=CPUOffload(offload_params=True), ignored_modules=None, state_dict_type=<StateDictType.FULL_STATE_DICT: 1>, state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), limit_all_gathers=False)
FullyShardedDataParallel(
(_fsdp_wrapped_module): PETModelForSeq2SeqLM(
(base_model): LoRAModel(
(model): MT5ForConditionalGeneration(
(shared): Embedding(250112, 512)
(encoder): T5Stack(
(embed_tokens): Embedding(250112, 512)
(block): ModuleList(
(0): FullyShardedDataParallel(
(_fsdp_wrapped_module): T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
(relative_attention_bias): Embedding(32, 6)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=512, out_features=1024, bias=False)
(wi_1): Linear(in_features=512, out_features=1024, bias=False)
(wo): Linear(in_features=1024, out_features=512, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(1): FullyShardedDataParallel(
(_fsdp_wrapped_module): T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=512, out_features=1024, bias=False)
(wi_1): Linear(in_features=512, out_features=1024, bias=False)
(wo): Linear(in_features=1024, out_features=512, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(2): FullyShardedDataParallel(
(_fsdp_wrapped_module): T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=512, out_features=1024, bias=False)
(wi_1): Linear(in_features=512, out_features=1024, bias=False)
(wo): Linear(in_features=1024, out_features=512, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(3): FullyShardedDataParallel(
(_fsdp_wrapped_module): T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=512, out_features=1024, bias=False)
(wi_1): Linear(in_features=512, out_features=1024, bias=False)
(wo): Linear(in_features=1024, out_features=512, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(4): FullyShardedDataParallel(
(_fsdp_wrapped_module): T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=512, out_features=1024, bias=False)
(wi_1): Linear(in_features=512, out_features=1024, bias=False)
(wo): Linear(in_features=1024, out_features=512, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(5): FullyShardedDataParallel(
(_fsdp_wrapped_module): T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=512, out_features=1024, bias=False)
(wi_1): Linear(in_features=512, out_features=1024, bias=False)
(wo): Linear(in_features=1024, out_features=512, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(6): FullyShardedDataParallel(
(_fsdp_wrapped_module): T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=512, out_features=1024, bias=False)
(wi_1): Linear(in_features=512, out_features=1024, bias=False)
(wo): Linear(in_features=1024, out_features=512, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(7): FullyShardedDataParallel(
(_fsdp_wrapped_module): T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=512, out_features=1024, bias=False)
(wi_1): Linear(in_features=512, out_features=1024, bias=False)
(wo): Linear(in_features=1024, out_features=512, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
)
(final_layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(decoder): T5Stack(
(embed_tokens): Embedding(250112, 512)
(block): ModuleList(
(0): FullyShardedDataParallel(
(_fsdp_wrapped_module): T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
(relative_attention_bias): Embedding(32, 6)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerCrossAttention(
(EncDecAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(2): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=512, out_features=1024, bias=False)
(wi_1): Linear(in_features=512, out_features=1024, bias=False)
(wo): Linear(in_features=1024, out_features=512, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(1): FullyShardedDataParallel(
(_fsdp_wrapped_module): T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerCrossAttention(
(EncDecAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(2): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=512, out_features=1024, bias=False)
(wi_1): Linear(in_features=512, out_features=1024, bias=False)
(wo): Linear(in_features=1024, out_features=512, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(2): FullyShardedDataParallel(
(_fsdp_wrapped_module): T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerCrossAttention(
(EncDecAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(2): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=512, out_features=1024, bias=False)
(wi_1): Linear(in_features=512, out_features=1024, bias=False)
(wo): Linear(in_features=1024, out_features=512, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(3): FullyShardedDataParallel(
(_fsdp_wrapped_module): T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerCrossAttention(
(EncDecAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(2): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=512, out_features=1024, bias=False)
(wi_1): Linear(in_features=512, out_features=1024, bias=False)
(wo): Linear(in_features=1024, out_features=512, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(4): FullyShardedDataParallel(
(_fsdp_wrapped_module): T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerCrossAttention(
(EncDecAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(2): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=512, out_features=1024, bias=False)
(wi_1): Linear(in_features=512, out_features=1024, bias=False)
(wo): Linear(in_features=1024, out_features=512, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(5): FullyShardedDataParallel(
(_fsdp_wrapped_module): T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerCrossAttention(
(EncDecAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(2): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=512, out_features=1024, bias=False)
(wi_1): Linear(in_features=512, out_features=1024, bias=False)
(wo): Linear(in_features=1024, out_features=512, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(6): FullyShardedDataParallel(
(_fsdp_wrapped_module): T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerCrossAttention(
(EncDecAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(2): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=512, out_features=1024, bias=False)
(wi_1): Linear(in_features=512, out_features=1024, bias=False)
(wo): Linear(in_features=1024, out_features=512, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(7): FullyShardedDataParallel(
(_fsdp_wrapped_module): T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerCrossAttention(
(EncDecAttention): T5Attention(
(q): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(k): Linear(in_features=512, out_features=384, bias=False)
(v): Linear(
in_features=512, out_features=384, bias=False
(lora_dropout): Dropout(p=0.1, inplace=False)
(lora_A): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
)
(lora_B): FullyShardedDataParallel(
(_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
)
)
(o): Linear(in_features=384, out_features=512, bias=False)
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(2): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=512, out_features=1024, bias=False)
(wi_1): Linear(in_features=512, out_features=1024, bias=False)
(wo): Linear(in_features=1024, out_features=512, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
)
(final_layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(lm_head): Linear(in_features=512, out_features=250112, bias=False)
)
)
)
)
- The number of trainable params is given below (a small helper to reproduce this count is sketched at the end of this section):
trainable params: 344064 || all params: 300520832 || trainable%: 0.11448923447676333
- Now, the issue is that, in comparison to plain PyTorch, FSDP consumes 1.65x more GPU memory instead of reducing it by a large amount (the expectation), while also greatly increasing the memory consumed on CPU. Below are the screenshots for the same. The hardware used is 1 A100 80GB GPU.
FSDP Full Shard with CPU offloading:
- When trying to use FSDP with CPU offloading with the bigscience/mt0-xxl model (13B params) on an A100 80GB GPU, it results in a GPU OOM error, whereas plain PyTorch consumes 56GB of GPU memory.
- Expected behaviour: efficiently handle frozen weights during training so that large models can be offloaded to CPU / sharded across GPUs properly, with optimizer state stored only for the trainable parameters. For example, with plain PyTorch the mt0-xxl (13B params) model takes up 56GB on GPU; it would be really helpful if CPU offloading let training work on a 16GB or 24GB GPU using FSDP.
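For reference, the custom auto_wrap_policy mentioned above can be reproduced roughly as follows. This is a minimal sketch consistent with the plugin printout (which combines lambda_auto_wrap_policy and transformer_auto_wrap_policy via _or_policy); the exact body of lambda_policy_fn here is an assumption, not the verbatim implementation:

```python
import functools

from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers.models.t5.modeling_t5 import T5Block


def lambda_policy_fn(module):
    # Put leaf modules whose weight is trainable (the lora_A / lora_B Linear
    # layers) into their own FSDP units, separate from the frozen layers.
    return (
        len(list(module.named_children())) == 0
        and getattr(module, "weight", None) is not None
        and module.weight.requires_grad
    )


lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
transformer_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={T5Block}
)
# Wrap a module whenever either sub-policy matches it.
auto_wrap_policy = functools.partial(
    _or_policy, policies=[lambda_policy, transformer_policy]
)
```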
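The trainable-parameter count reported above can be reproduced with a small helper along these lines (a sketch; the name print_trainable_parameters is ours, not part of the reported setup):

```python
def print_trainable_parameters(model):
    # Count parameters that will actually receive gradients (the lora_* layers)
    # versus the full, mostly frozen, model.
    trainable, total = 0, 0
    for param in model.parameters():
        total += param.numel()
        if param.requires_grad:
            trainable += param.numel()
    print(
        f"trainable params: {trainable} || all params: {total} || "
        f"trainable%: {100 * trainable / total}"
    )
```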
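For completeness, the kind of direct PyTorch FSDP setup we expect to be memory-efficient looks roughly like the sketch below, using the same options as the Accelerate plugin printed above (full sharding, CPU offload of parameters, BACKWARD_PRE prefetch). Here model and auto_wrap_policy are assumed to be built already, and the optimizer only receives the trainable parameters:

```python
import torch
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)

# `model` is the LoRA model and `auto_wrap_policy` the combined policy
# sketched above; both are assumed to exist at this point.
fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=CPUOffload(offload_params=True),
    auto_wrap_policy=auto_wrap_policy,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    device_id=torch.cuda.current_device(),
)

# Expectation: optimizer state is kept only for the ~0.11% of trainable params.
optimizer = torch.optim.AdamW(
    (p for p in fsdp_model.parameters() if p.requires_grad), lr=1e-3
)
```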
Versions
Collecting environment information...
PyTorch version: 1.14.0.dev20221117+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.25.0
Libc version: glibc-2.31
Python version: 3.10.4 (main, Mar 31 2022, 08:41:55) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-125-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA DGX Display
GPU 4: NVIDIA A100-SXM4-80GB
Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.7.1
[pip3] numpy==1.23.4
[pip3] torch==1.14.0.dev20221117+cu117
[pip3] torchaudio==0.14.0.dev20221117
[pip3] torchtriton==2.0.0+0d7e753227
[pip3] torchvision==0.15.0.dev20221117
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py310h7f8727e_0
[conda] mkl_fft 1.3.1 py310hd6ae3a3_0
[conda] mkl_random 1.2.2 py310h00e6091_0
[conda] msgpack-numpy 0.4.7.1 pypi_0 pypi
[conda] numpy 1.23.4 pypi_0 pypi
[conda] pytorch-cuda 11.7 h67b0de4_0 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch 1.14.0.dev20221117+cu117 pypi_0 pypi
[conda] torchaudio 0.14.0.dev20221117 py310_cu117 pytorch-nightly
[conda] torchtriton 2.0.0+0d7e753227 pypi_0 pypi
[conda] torchvision 0.15.0.dev20221117 py310_cu117 pytorch-nightly
cc @ezyang @gchanan @zou3519 @kadeng @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin