[FSDP] FSDP with CPU offload consumes 1.65X more GPU memory when training models with most of the params frozen #91165
Open
@pacman100

Description


🐛 Describe the bug

Context: We increasingly have situations where a large part of the model being trained is frozen. Since these are very large LLMs, we want to leverage FSDP with CPU offloading so that training, with only a tiny fraction of the parameters trainable, fits on consumer GPUs. To that end, below is an example of fine-tuning bigscience/mt0-small with the LoRA parameter-efficient fine-tuning method, with and without FSDP.

To fine-tune with FSDP:

  1. Following "How do I freeze weights when using FSDP?" (huggingface/accelerate#807), and to avoid AssertionError: expects all parameters to have same requires_grad, we created a custom auto_wrap_policy that puts the layers with trainable params in separate FSDP units from the frozen ones (a minimal sketch of such a policy is included after these steps). The resulting model and the FSDP options are given below. We are using Accelerate's FSDP integration; the trainable params have the lora_ prefix:
FullyShardedDataParallelPlugin(sharding_strategy=<ShardingStrategy.FULL_SHARD: 1>, backward_prefetch=<BackwardPrefetch.BACKWARD_PRE: 1>, mixed_precision_policy=None, auto_wrap_policy=functools.partial(<function _or_policy at 0x7f022c1a8430>, policies=[functools.partial(<function lambda_auto_wrap_policy at 0x7f022c1a8160>, lambda_fn=<function fsdp_auto_wrap_policy.<locals>.lambda_policy_fn at 0x7f01ec31e3b0>), functools.partial(<function transformer_auto_wrap_policy at 0x7f022c1a8310>, transformer_layer_cls=(<class 'pet.tuners.prefix_tuning.PrefixEncoder'>, <class 'pet.tuners.p_tuning.PromptEncoder'>, <class 'pet.tuners.prompt_tuning.PromptEmbedding'>, <class 'transformers.models.t5.modeling_t5.T5Block'>))]), cpu_offload=CPUOffload(offload_params=True), ignored_modules=None, state_dict_type=<StateDictType.FULL_STATE_DICT: 1>, state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), limit_all_gathers=False)


FullyShardedDataParallel(
  (_fsdp_wrapped_module): PETModelForSeq2SeqLM(
    (base_model): LoRAModel(
      (model): MT5ForConditionalGeneration(
        (shared): Embedding(250112, 512)
        (encoder): T5Stack(
          (embed_tokens): Embedding(250112, 512)
          (block): ModuleList(
            (0): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                      (relative_attention_bias): Embedding(32, 6)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (1): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (2): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (3): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (4): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (5): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (6): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (7): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
          )
          (final_layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (decoder): T5Stack(
          (embed_tokens): Embedding(250112, 512)
          (block): ModuleList(
            (0): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                      (relative_attention_bias): Embedding(32, 6)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (1): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (2): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (3): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (4): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (5): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (6): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
            (7): FullyShardedDataParallel(
              (_fsdp_wrapped_module): T5Block(
                (layer): ModuleList(
                  (0): T5LayerSelfAttention(
                    (SelfAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (1): T5LayerCrossAttention(
                    (EncDecAttention): T5Attention(
                      (q): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (k): Linear(in_features=512, out_features=384, bias=False)
                      (v): Linear(
                        in_features=512, out_features=384, bias=False
                        (lora_dropout): Dropout(p=0.1, inplace=False)
                        (lora_A): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=512, out_features=8, bias=False)
                        )
                        (lora_B): FullyShardedDataParallel(
                          (_fsdp_wrapped_module): Linear(in_features=8, out_features=384, bias=False)
                        )
                      )
                      (o): Linear(in_features=384, out_features=512, bias=False)
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (2): T5LayerFF(
                    (DenseReluDense): T5DenseGatedActDense(
                      (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                      (wi_1): Linear(in_features=512, out_features=1024, bias=False)
                      (wo): Linear(in_features=1024, out_features=512, bias=False)
                      (dropout): Dropout(p=0.1, inplace=False)
                      (act): NewGELUActivation()
                    )
                    (layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
              )
            )
          )
          (final_layer_norm): FusedRMSNorm(torch.Size([512]), eps=1e-06, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (lm_head): Linear(in_features=512, out_features=250112, bias=False)
      )
    )
  )
)
  2. The number of trainable params is given below:
trainable params: 344064 || all params: 300520832 || trainable%: 0.11448923447676333
  3. Now, the issue: compared to plain PyTorch, FSDP consumes 1.65X more GPU memory instead of reducing it by a large amount (the expectation), while also greatly increasing the memory consumed on CPU. Screenshots for both runs are below. The hardware used is 1 A100 80GB GPU.

Plain PyTorch:
[screenshot: memory usage, 2022-12-20 3:45 PM]

FSDP Full Shard with CPU offloading:
[screenshot: memory usage, 2022-12-20 4:14 PM]

  4. When trying to use FSDP with CPU offloading for the bigscience/mt0-xxl model (13B params) on an A100 80GB GPU, it results in a GPU OOM error, whereas plain PyTorch consumes 56GB of GPU memory.

  5. Expected behaviour: Handle frozen weights efficiently during training so that large models can be offloaded to CPU / sharded across GPUs properly, with optimizer state stored only for the trainable parameters (see the optimizer sketch after these steps). For example, with plain PyTorch the mt0-xxl (13B params) model takes up 56GB on the GPU; it would be really helpful if FSDP with CPU offloading made training work on a 16GB or 24GB GPU.
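
For reference, below is a minimal sketch of the kind of custom auto_wrap_policy described in step 1. It mirrors the _or_policy composition visible in the FullyShardedDataParallelPlugin repr above: a lambda policy that gives the trainable lora_ leaf modules their own FSDP units, OR'd with a transformer policy for T5Block. The lambda_policy_fn body is an illustrative assumption (the PET prompt-encoder classes from the repr are omitted), not the exact Accelerate/PEFT implementation, and _or_policy is a private helper:

import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers.models.t5.modeling_t5 import T5Block


def lambda_policy_fn(module: nn.Module) -> bool:
    # Wrap leaf modules whose weight is trainable (e.g. lora_A / lora_B) into their
    # own FSDP units, so trainable and frozen params never share a flat parameter.
    return (
        len(list(module.named_children())) == 0
        and getattr(module, "weight", None) is not None
        and module.weight.requires_grad
    )


auto_wrap_policy = functools.partial(
    _or_policy,
    policies=[
        functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn),
        functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={T5Block}),
    ],
)
# auto_wrap_policy is then passed to FSDP (or to Accelerate's
# FullyShardedDataParallelPlugin), producing the wrapping shown above.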

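As a reference point for steps 2 and 5, here is a minimal sketch (assuming model is the LoRA-augmented model above; the optimizer choice and learning rate are illustrative) of how the trainable-parameter count is computed and of the plain-PyTorch setup in which optimizer state exists only for the trainable parameters:

import torch

# Reproduce the counts printed in step 2.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
all_params = sum(p.numel() for p in model.parameters())
print(
    f"trainable params: {trainable_params} || all params: {all_params} "
    f"|| trainable%: {100 * trainable_params / all_params}"
)

# Only trainable params are handed to the optimizer, so Adam moments cover roughly
# 0.11% of the 300M params; the expectation in step 5 is that FSDP with CPU
# offloading would likewise keep optimizer state proportional to this subset.
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)
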
Versions

Collecting environment information...
PyTorch version: 1.14.0.dev20221117+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.25.0
Libc version: glibc-2.31

Python version: 3.10.4 (main, Mar 31 2022, 08:41:55) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-125-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA DGX Display
GPU 4: NVIDIA A100-SXM4-80GB

Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.7.1
[pip3] numpy==1.23.4
[pip3] torch==1.14.0.dev20221117+cu117
[pip3] torchaudio==0.14.0.dev20221117
[pip3] torchtriton==2.0.0+0d7e753227
[pip3] torchvision==0.15.0.dev20221117
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py310h7f8727e_0
[conda] mkl_fft 1.3.1 py310hd6ae3a3_0
[conda] mkl_random 1.2.2 py310h00e6091_0
[conda] msgpack-numpy 0.4.7.1 pypi_0 pypi
[conda] numpy 1.23.4 pypi_0 pypi
[conda] pytorch-cuda 11.7 h67b0de4_0 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch 1.14.0.dev20221117+cu117 pypi_0 pypi
[conda] torchaudio 0.14.0.dev20221117 py310_cu117 pytorch-nightly
[conda] torchtriton 2.0.0+0d7e753227 pypi_0 pypi
[conda] torchvision 0.15.0.dev20221117 py310_cu117 pytorch-nightly

cc @ezyang @gchanan @zou3519 @kadeng @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin

Metadata

Assignees

No one assigned

Labels

high priority · module: fsdp · oncall: distributed · triage review · triaged
