[BUG] Broken pipeline parallel for ESM2 · Issue #784 · NVIDIA/bionemo-framework · GitHub

[BUG] Broken pipeline parallel for ESM2 #784


Open
dorotat-nv opened this issue Mar 25, 2025 · 1 comment · May be fixed by #829
Labels
bug (Something isn't working), ESM2

Comments

@dorotat-nv
Collaborator
dorotat-nv commented Mar 25, 2025

BioNeMo Framework Version

32be470

Bug Description

The ESM2 pretraining does not work with --pipeline-model-parallel-size=2. This option used to work, for example on commit 9cec09f963cf5747193262b84bb700148280792b:

train_esm2 --train-cluster-path=/data/20240809_uniref_2024_03/data/train_clusters.parquet --train-database-path=/data/20240809_uniref_2024_03/data/train.db --valid-cluster-path=/data/20240809_uniref_2024_03/data/valid_clusters.parquet --valid-database-path=/data/20240809_uniref_2024_03/data/validation.db --micro-batch-size=16 --num-nodes=2 --num-gpus=8 --val-check-interval=50 --limit-val-batches=1 --num-steps=200 --min-seq-length=1024 --max-seq-length=1024 --num-layers=33 --hidden-size=1280 --num-attention-heads=20 --ffn-hidden-size=5120 --create-tensorboard-logger --accumulate-grad-batches=1 --pipeline-model-parallel-size=2 --tensor-model-parallel-size=1 --disable-checkpointing
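
The assertion raised in the logs below ("num_layers should be divisible by pipeline_model_parallel_size") is consistent with the flags above: 33 encoder layers cannot be split evenly across 2 pipeline stages. A minimal standalone sketch of that constraint with the values from this run (illustrative only, not the actual Megatron-core implementation):

```python
# Illustrative sketch only: a standalone version of the divisibility check that
# fails in the traceback below (megatron/core/transformer/transformer_block.py,
# get_num_layers_to_build). Not the actual Megatron-core code.

def layers_per_pipeline_stage(num_layers: int, pipeline_model_parallel_size: int) -> int:
    assert num_layers % pipeline_model_parallel_size == 0, (
        "num_layers should be divisible by pipeline_model_parallel_size"
    )
    return num_layers // pipeline_model_parallel_size


# Values from the command above: --num-layers=33, --pipeline-model-parallel-size=2.
layers_per_pipeline_stage(33, 2)  # raises AssertionError, matching the logs below
```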

Steps to Reproduce

  1. Get the Docker image at the specific commit.
  2. Run the ESM2 training with --pipeline-model-parallel-size=2, i.e.:

train_esm2 --train-cluster-path=/data/20240809_uniref_2024_03/data/train_clusters.parquet --train-database-path=/data/20240809_uniref_2024_03/data/train.db --valid-cluster-path=/data/20240809_uniref_2024_03/data/valid_clusters.parquet --valid-database-path=/data/20240809_uniref_2024_03/data/validation.db --micro-batch-size=16 --num-nodes=2 --num-gpus=8 --val-check-interval=50 --limit-val-batches=1 --num-steps=200 --min-seq-length=1024 --max-seq-length=1024 --num-layers=33 --hidden-size=1280 --num-attention-heads=20 --ffn-hidden-size=5120 --create-tensorboard-logger --accumulate-grad-batches=1 --pipeline-model-parallel-size=2 --tensor-model-parallel-size=1 --disable-checkpointing
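
A quick offline sanity check of the layer split implied by these flags (illustrative only; it merely mirrors the divisibility rule reported by the AssertionError below, not BioNeMo or Megatron internals):

```python
# Quick offline check of the layer split implied by the flags in the repro command.
# Illustrative only: it mirrors the divisibility rule from the AssertionError below.

for num_layers, pp_size in [(33, 1), (33, 2)]:
    if num_layers % pp_size == 0:
        print(f"--num-layers={num_layers} --pipeline-model-parallel-size={pp_size}: "
              f"OK, {num_layers // pp_size} layers per stage")
    else:
        print(f"--num-layers={num_layers} --pipeline-model-parallel-size={pp_size}: "
              f"fails the num_layers divisibility assertion")
```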

Error Messages and Logs

 0: + mkdir -p /workspace/bionemo2
 0: + cd /workspace/bionemo2
 0: + WANDB_API_KEY=[MASKED]
 0: + train_esm2 --train-cluster-path=/data/20240809_uniref_2024_03/data/train_clusters.parquet --train-database-path=/data/20240809_uniref_2024_03/data/train.db --valid-cluster-path=/data/20240809_uniref_2024_03/data/valid_clusters.parquet --valid-database-path=/data/20240809_uniref_2024_03/data/validation.db --micro-batch-size=16 --num-nodes=2 --num-gpus=8 --val-check-interval=50 --limit-val-batches=1 --num-steps=200 --min-seq-length=1024 --max-seq-length=1024 --num-layers=33 --hidden-size=1280 --num-attention-heads=20 --ffn-hidden-size=5120 --create-tensorboard-logger --experiment-name=16bs_2node_8gpu_200s_bf16-mixedprec --result-dir=/jet/logs/recipe/model-esm2_variant-train_bionemo_perf_precision-bf16-mixed_nodes-2_gpus-8_batch_size-16_acc_grad-1_config_name-650m_max_steps-200_pp-2_target-dgxa100-dracooci-iad_tp-1/tensorboard_logs --wandb-project=jet--perf--main --wandb-group=esm2_train_650M__dgxa100_dracooci-iad --wandb-job-type=2025-03-2209:28:13Z_32be4700e8dd53ed01af4e749e79b2d3a6273f07_255 --log-every-n-steps=10 --accumulate-grad-batches=1 --pipeline-model-parallel-size=2 --tensor-model-parallel-size=1 --disable-checkpointing
15: Collecting system details... DONE (took: 4.75e+01[s], started: 1742637865.064)
15: Collecting libraries versions... DONE (took: 5.60e-01[s], started: 1742637912.612)
15: Saving the log... DONE (took: 4.41e-02[s], started: 1742637913.172)
15: 
15: [WORKLOAD] [1/1] stage=script
 8: Initializing distributed: GLOBAL_RANK: 8, MEMBER: 9/16
10: Initializing distributed: GLOBAL_RANK: 10, MEMBER: 11/16
14: Initializing distributed: GLOBAL_RANK: 14, MEMBER: 15/16
11: Initializing distributed: GLOBAL_RANK: 11, MEMBER: 12/16
12: Initializing distributed: GLOBAL_RANK: 12, MEMBER: 13/16
13: Initializing distributed: GLOBAL_RANK: 13, MEMBER: 14/16
 9: Initializing distributed: GLOBAL_RANK: 9, MEMBER: 10/16
 3: Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/16
 7: Initializing distributed: GLOBAL_RANK: 7, MEMBER: 8/16
 6: Initializing distributed: GLOBAL_RANK: 6, MEMBER: 7/16
 2: Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/16
 5: Initializing distributed: GLOBAL_RANK: 5, MEMBER: 6/16
 0: [WARNING  | py.warnings        ]: /usr/local/lib/python3.12/dist-packages/mamba_ssm/ops/selective_scan_interface.py:163: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
 0:   @custom_fwd
 0: 
 0: [WARNING  | py.warnings        ]: /usr/local/lib/python3.12/dist-packages/mamba_ssm/ops/selective_scan_interface.py:239: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
 0:   @custom_bwd
 0: 
 0: [WARNING  | py.warnings        ]: /usr/local/lib/python3.12/dist-packages/mamba_ssm/ops/triton/layer_norm.py:985: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
 0:   @custom_fwd
 0: 
 0: [WARNING  | py.warnings        ]: /usr/local/lib/python3.12/dist-packages/mamba_ssm/ops/triton/layer_norm.py:1044: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
 0:   @custom_bwd
 0: 
 0: [WARNING  | py.warnings        ]: /usr/local/lib/python3.12/dist-packages/mamba_ssm/distributed/tensor_parallel.py:25: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
 0:   @custom_fwd
 0: 
 0: [WARNING  | py.warnings        ]: /usr/local/lib/python3.12/dist-packages/mamba_ssm/distributed/tensor_parallel.py:61: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
 0:   @custom_bwd
 0: 
 0: [WARNING  | py.warnings        ]: /usr/local/lib/python3.12/dist-packages/mamba_ssm/ops/triton/ssd_combined.py:757: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
 0:   @custom_fwd
 0: 
 0: [WARNING  | py.warnings        ]: /usr/local/lib/python3.12/dist-packages/mamba_ssm/ops/triton/ssd_combined.py:835: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
 0:   @custom_bwd
 0: 
 0: [INFO     | pytorch_lightning.utilities.rank_zero]: Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
 0: [INFO     | pytorch_lightning.utilities.rank_zero]: GPU available: True (cuda), used: True
 0: [INFO     | pytorch_lightning.utilities.rank_zero]: TPU available: False, using: 0 TPU cores
 0: [INFO     | pytorch_lightning.utilities.rank_zero]: HPU available: False, using: 0 HPUs
 0: [INFO     | pytorch_lightning.utilities.rank_zero]: `Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] Experiments will be logged at /jet/logs/recipe/model-esm2_variant-train_bionemo_perf_precision-bf16-mixed_nodes-2_gpus-8_batch_size-16_acc_grad-1_config_name-650m_max_steps-200_pp-2_target-dgxa100-dracooci-iad_tp-1/tensorboard_logs/16bs_2node_8gpu_200s_bf16-mixedprec/dev
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] Rank 0 has data parallel group : [0, 1, 2, 3, 4, 5, 6, 7]
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] Rank 0 has combined group of data parallel and context parallel : [0, 1, 2, 3, 4, 5, 6, 7]
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] All data parallel group ranks with context parallel combined: [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15]]
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] Ranks 0 has data parallel rank: 0
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] Rank 0 has context parallel group: [0]
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] All context parallel group ranks: [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15]]
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] Ranks 0 has context parallel rank: 0
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] Rank 0 has model parallel group: [0, 8]
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] All model parallel group ranks: [[0, 8], [1, 9], [2, 10], [3, 11], [4, 12], [5, 13], [6, 14], [7, 15]]
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] Rank 0 has tensor model parallel group: [0]
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] All tensor model parallel group ranks: [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15]]
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] Rank 0 has tensor model parallel rank: 0
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] Rank 0 has pipeline model parallel group: [0, 8]
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] Rank 0 has embedding group: [0, 8]
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] All pipeline model parallel group ranks: [[0, 8], [1, 9], [2, 10], [3, 11], [4, 12], [5, 13], [6, 14], [7, 15]]
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] Rank 0 has pipeline model parallel rank 0
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] All embedding group ranks: [[0, 8], [1, 9], [2, 10], [3, 11], [4, 12], [5, 13], [6, 14], [7, 15]]
 0: [NeMo I 2025-03-22 03:05:22 nemo_logging:393] Rank 0 has embedding rank: 0
 0: Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/16
 1: Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/16
 4: Initializing distributed: GLOBAL_RANK: 4, MEMBER: 5/16
15: Initializing distributed: GLOBAL_RANK: 15, MEMBER: 16/16
 0: [INFO     | pytorch_lightning.utilities.rank_zero]: ----------------------------------------------------------------------------------------------------
 0: distributed_backend=nccl
 0: All distributed processes registered. Starting with 16 processes
 0: ----------------------------------------------------------------------------------------------------
 0: 
 0: wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
 0: wandb: Currently logged in as: dorotat_nv (clara-discovery) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
 0: wandb: Tracking run with wandb version 0.19.8
 0: wandb: Run data is saved locally in /jet/logs/recipe/model-esm2_variant-train_bionemo_perf_precision-bf16-mixed_nodes-2_gpus-8_batch_size-16_acc_grad-1_config_name-650m_max_steps-200_pp-2_target-dgxa100-dracooci-iad_tp-1/tensorboard_logs/16bs_2node_8gpu_200s_bf16-mixedprec/wandb/run-20250322_030528-ug1tfvgh
 0: wandb: Run `wandb offline` to turn off syncing.
 0: wandb: Syncing run 16bs_2node_8gpu_200s_bf16-mixedprec
 0: wandb: ⭐️ View project at https://wandb.ai/clara-discovery/jet--perf--main
 0: wandb: 🚀 View run at https://wandb.ai/clara-discovery/jet--perf--main/runs/ug1tfvgh
 0: [NeMo I 2025-03-22 03:06:55 nemo_logging:393] Padded vocab_size: 128, original vocab_size: 33, dummy tokens: 95.
 0: Traceback (most recent call last):
10: [rank10]: Traceback (most recent call last):
10: [rank10]:   File "/usr/local/bin/train_esm2", line 8, in <module>
10: [rank10]:     sys.exit(train_esm2_entrypoint())
10: [rank10]:              ^^^^^^^^^^^^^^^^^^^^^^^
10: [rank10]:   File "/usr/local/lib/python3.12/dist-packages/bionemo/esm2/scripts/train_esm2.py", line 386, in train_esm2_entrypoint
10: [rank10]:     main(
10: [rank10]:   File "/usr/local/lib/python3.12/dist-packages/bionemo/esm2/scripts/train_esm2.py", line 370, in main
10: [rank10]:     llm.train(
10: [rank10]:   File "/usr/local/lib/python3.12/dist-packages/nemo/collections/llm/api.py", line 109, in train
10: [rank10]:     trainer.fit(model, data)
10: [rank10]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
10: [rank10]:     call._call_and_handle_interrupt(
10: [rank10]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/call.py", line 46, in _call_and_handle_interrupt
 0:   File "/usr/local/bin/train_esm2", line 8, in <module>
 0:     sys.exit(train_esm2_entrypoint())
 0:              ^^^^^^^^^^^^^^^^^^^^^^^
10: [rank10]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
10: [rank10]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
10: [rank10]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
10: [rank10]:     return function(*args, **kwargs)
10: [rank10]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
10: [rank10]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
10: [rank10]:     self._run(model, ckpt_path=ckpt_path)
10: [rank10]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 945, in _run
10: [rank10]:     call._call_configure_model(self)
10: [rank10]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/call.py", line 119, in _call_configure_model
10: [rank10]:     _call_lightning_module_hook(trainer, "configure_model")
 0:   File "/usr/local/lib/python3.12/dist-packages/bionemo/esm2/scripts/train_esm2.py", line 386, in train_esm2_entrypoint
 0:     main(
10: [rank10]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/call.py", line 167, in _call_lightning_module_hook
10: [rank10]:     output = fn(*args, **kwargs)
10: [rank10]:              ^^^^^^^^^^^^^^^^^^^
10: [rank10]:   File "/usr/local/lib/python3.12/dist-packages/bionemo/llm/lightning.py", line 306, in configure_model
10: [rank10]:     self.config.configure_model(**self.module_construct_args)
10: [rank10]:   File "/usr/local/lib/python3.12/dist-packages/bionemo/llm/model/biobert/model.py", line 558, in configure_model
10: [rank10]:     model = self.model_cls(
10: [rank10]:             ^^^^^^^^^^^^^^^
10: [rank10]:   File "/usr/local/lib/python3.12/dist-packages/bionemo/esm2/model/model.py", line 166, in __init__
10: [rank10]:     self.encoder = TransformerBlock(
10: [rank10]:                    ^^^^^^^^^^^^^^^^^
10: [rank10]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_block.py", line 224, in __init__
10: [rank10]:     self.submodules = _get_block_submodules(config, spec)
 0:   File "/usr/local/lib/python3.12/dist-packages/bionemo/esm2/scripts/train_esm2.py", line 370, in main
 0:     llm.train(
10: [rank10]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
10: [rank10]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_block.py", line 201, in _get_block_submodules
10: [rank10]:     num_layers = get_num_layers_to_build(config)
Multiple ranks fail at model construction with the same assertion. Ranks 2, 10, and 11 emit the identical traceback interleaved in the log, and rank 0's output follows the same call stack; a representative, de-interleaved traceback from rank 15:

15: [rank15]: Traceback (most recent call last):
15: [rank15]:   File "/usr/local/bin/train_esm2", line 8, in <module>
15: [rank15]:     sys.exit(train_esm2_entrypoint())
15: [rank15]:              ^^^^^^^^^^^^^^^^^^^^^^^
15: [rank15]:   File "/usr/local/lib/python3.12/dist-packages/bionemo/esm2/scripts/train_esm2.py", line 386, in train_esm2_entrypoint
15: [rank15]:     main(
15: [rank15]:   File "/usr/local/lib/python3.12/dist-packages/bionemo/esm2/scripts/train_esm2.py", line 370, in main
15: [rank15]:     llm.train(
15: [rank15]:   File "/usr/local/lib/python3.12/dist-packages/nemo/collections/llm/api.py", line 109, in train
15: [rank15]:     trainer.fit(model, data)
15: [rank15]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
15: [rank15]:     call._call_and_handle_interrupt(
15: [rank15]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/call.py", line 46, in _call_and_handle_interrupt
15: [rank15]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
15: [rank15]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
15: [rank15]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
15: [rank15]:     return function(*args, **kwargs)
15: [rank15]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
15: [rank15]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
15: [rank15]:     self._run(model, ckpt_path=ckpt_path)
15: [rank15]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 945, in _run
15: [rank15]:     call._call_configure_model(self)
15: [rank15]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/call.py", line 119, in _call_configure_model
15: [rank15]:     _call_lightning_module_hook(trainer, "configure_model")
15: [rank15]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/call.py", line 167, in _call_lightning_module_hook
15: [rank15]:     output = fn(*args, **kwargs)
15: [rank15]:              ^^^^^^^^^^^^^^^^^^^
15: [rank15]:   File "/usr/local/lib/python3.12/dist-packages/bionemo/llm/lightning.py", line 306, in configure_model
15: [rank15]:     self.config.configure_model(**self.module_construct_args)
15: [rank15]:   File "/usr/local/lib/python3.12/dist-packages/bionemo/llm/model/biobert/model.py", line 558, in configure_model
15: [rank15]:     model = self.model_cls(
15: [rank15]:             ^^^^^^^^^^^^^^^
15: [rank15]:   File "/usr/local/lib/python3.12/dist-packages/bionemo/esm2/model/model.py", line 166, in __init__
15: [rank15]:     self.encoder = TransformerBlock(
15: [rank15]:                    ^^^^^^^^^^^^^^^^^
15: [rank15]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_block.py", line 224, in __init__
15: [rank15]:     self.submodules = _get_block_submodules(config, spec)
15: [rank15]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
15: [rank15]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_block.py", line 201, in _get_block_submodules
15: [rank15]:     num_layers = get_num_layers_to_build(config)
15: [rank15]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
15: [rank15]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_block.py", line 110, in get_num_layers_to_build
15: [rank15]:     num_layers % config.pipeline_model_parallel_size == 0
15: [rank15]: AssertionError: num_layers should be divisible by pipeline_model_parallel_size
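
For context, the assertion fires inside Megatron-Core's get_num_layers_to_build before layers are assigned to the local pipeline stage. A minimal sketch of the constraint with the values from this run (an illustration only, not the actual Megatron-Core source):

# Illustration of the failing check, assuming the simplified form below:
# the total layer count must split evenly across pipeline stages.
num_layers = 33                      # --num-layers=33
pipeline_model_parallel_size = 2     # --pipeline-model-parallel-size=2

assert num_layers % pipeline_model_parallel_size == 0, (
    "num_layers should be divisible by pipeline_model_parallel_size"
)  # 33 % 2 == 1, so every rank raises this AssertionError during configure_model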

Docker Image

No response

System Information

Not provided (issue template placeholders left unfilled).

Additional Context

No response

@dorotat-nv added the bug and ESM2 labels on Mar 25, 2025
@sichu2023 (Collaborator) commented:
This is likely specific to this configuration, i.e. 33 layers with pipeline parallel size 2: 33 is not divisible by 2, which trips the assertion above. How about trying 32 layers instead?
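
One way to catch this before launching a multi-node job is a small pre-flight check on the chosen layer count. The helper below is hypothetical (not part of BioNeMo or the train_esm2 CLI), just a sketch of the divisibility rule:

def check_pipeline_split(num_layers: int, pp_size: int) -> None:
    """Raise early if the layer count cannot be split evenly across pipeline stages."""
    if num_layers % pp_size != 0:
        raise ValueError(
            f"num_layers={num_layers} is not divisible by "
            f"pipeline_model_parallel_size={pp_size}"
        )

check_pipeline_split(32, 2)  # OK: 16 layers per stage
check_pipeline_split(33, 2)  # raises ValueError, matching the failure reported above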
