[DEV] Support sync params in tensor parallel config by From00 · Pull Request #8311 · PaddlePaddle/PaddleNLP · GitHub
[DEV] Support sync params in tensor parallel config #8311

Merged
16 changes: 15 additions & 1 deletion docs/trainer.md
@@ -521,6 +521,20 @@ Trainer is a simple but feature-complete Paddle training and evaluation module, and
defaults to -1, meaning tensor parallel is not used. Suggest tensor_parallel_degree <= 8 for better performance.
Note: this needs model support in the source code; currently GPT/BLOOM/LLAMA/CLM/CHATGLM are supported.

--tensor_parallel_config
For tensor parallelism, some options affect training performance. Those option configs are managed together here and passed in as a single str.
The following options are supported:
enable_delay_scale_loss : accumulate gradients and divide all gradients by the number of accumulation steps in the optimizer step, instead of dividing the loss by the accumulation steps directly.
sync_param : in the optimizer step, use broadcast to sync all parameters with is_distributed=False
sync_grad : in the optimizer step, use broadcast to sync all gradients with is_distributed=False
sync_moment : in the optimizer step, use broadcast to sync all momentum with is_distributed=False
Comment on lines +528 to +530
Collaborator:

Aren't these enabled by default?

Collaborator (Author):

When these options are not configured in PaddleNLP, the existing behavior is not affected.
Inside the framework, sync_param is enabled by default while sync_grad and sync_moment are disabled by default, and the parameter names to synchronize default to sync_param_name = ["embedding", "layer_norm", ".b_"]; no other parameters are synchronized.
When these switches are enabled through PaddleNLP, all parameters are forcibly synchronized.
There is a comment in the code explaining this: https://github.com/PaddlePaddle/PaddleNLP/pull/8311/files#diff-477a2a51a1a5694f5db999c8695a3f6ec8fd4f08ded299fb66176651e9d6ebadR1100-R1102
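
To make the filtering behavior described in that reply concrete, here is a minimal, self-contained sketch. The substring-match rule is an assumption inferred from the `sync_param_name = [""] matches any parameter name` comment added later in this diff, and the function and variable names below are hypothetical, not Paddle or PaddleNLP APIs.

```python
# Illustrative sketch only: how a name filter such as sync_param_name selects
# which tensors get broadcast. Substring matching is assumed, consistent with
# the code comment that [""] matches any parameter name.

def should_sync(param_name, name_filters):
    """Return True if any filter string occurs in the parameter name."""
    return any(f in param_name for f in name_filters)

default_filters = ["embedding", "layer_norm", ".b_"]  # framework default, per the reply
force_all = [""]  # what PaddleNLP sets when sync_param is enabled

for name in ["word_embeddings.weight", "linear_0.w_0", "linear_0.b_0"]:
    print(name, should_sync(name, default_filters), should_sync(name, force_all))
# word_embeddings.weight True True
# linear_0.w_0 False True
# linear_0.b_0 True True
```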


Some additional configs can highly affect the usage of tensor parallel; we provide some options to configure them.
The following configs are supported:
enable_delay_scale_loss, accumulate gradients until the optimizer step and divide all gradients by the number of accumulation steps, instead of dividing the loss by the accumulation steps directly.
sync_param, in the optimizer step, use broadcast to sync parameters whose attr 'is_distributed' is False.
sync_grad, in the optimizer step, use broadcast to sync gradients whose attr 'is_distributed' is False.
sync_moment, in the optimizer step, use broadcast to sync momentums whose attr 'is_distributed' is False.
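
As a usage illustration (not part of the PR), the options above are combined into one string. A minimal sketch, assuming a space-separated format like the other *_parallel_config options and hypothetical values for the degree and output directory:

```python
# Hypothetical usage sketch; run under `python -m paddle.distributed.launch`
# (or your usual launcher) so the hybrid parallel strategy can be initialized.
# Flag names are taken from the documentation above.
from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",   # hypothetical path
    tensor_parallel_degree=2,     # hypothetical degree
    tensor_parallel_config="enable_delay_scale_loss sync_param sync_grad sync_moment",
)
# Assumed command-line equivalent:
#   --tensor_parallel_degree 2 \
#   --tensor_parallel_config "enable_delay_scale_loss sync_param sync_grad sync_moment"
```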

--pipeline_parallel_degree
流水线并行是Megatron论文针对多层Transformer结构提出的按层划分方法.
@@ -549,7 +563,7 @@ Trainer is a simple but feature-complete Paddle training and evaluation module, and
following config is support:
disable_p2p_cache_shape, if you max sequence length is varying, please set disable_p2p_cache_shape.
disable_partial_send_recv, optmize send speed for tensor parallel.
enable_delay_scale_loss, accumulate gradients util optimizer step, all gradients div by inner pipeline accumute step. instead of div accumute step on loss directly.
enable_delay_scale_loss, accumulate gradients until optimizer step, all gradients div by inner pipeline accumute step. instead of div accumute step on loss directly.
enable_dp_comm_overlap, fuse data parallel gradient communication.

--data_parallel_config
36 changes: 31 additions & 5 deletions paddlenlp/trainer/training_args.py
@@ -241,13 +241,16 @@ class TrainingArguments:
enable_mp_async_allreduce, it supports all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear backward when it set True, which can accelerate model parallel performance.
enable_mp_skip_c_identity, it supports skip c_identity in ColumnParallelLinear and RowParallelLinear. It only works when set mp_async_allreduce is True. It can accelerate model parallel further.
enable_mp_fused_linear_param_grad_add, it supports fused_linear_param_grad_add in ColumnParallelLinear (cuda >= 11.6). It only works when mp_async_allreduce is true. It can accelerate model parallel further.
enable_delay_scale_loss, accumulate gradients util optimizer step, all gradients div by accumute step. instead of div accumute step on loss directly.
enable_delay_scale_loss, accumulate gradients until optimizer step, all gradients div by accumute step. instead of div accumute step on loss directly.
sync_param, in optimizer step, use broadcast to sync parameters those attr 'is_distributed' is False.
sync_grad, in optimizer step, use broadcast to sync gradients those attr 'is_distributed' is False.
sync_moment, in optimizer step, use broadcast to sync momentums those attr 'is_distributed' is False.
pipeline_parallel_config (`str`, *optional*)(
Some additional config it highly affect the useage of pipeline parallel, we provide some option to config it.
following config is support:
disable_p2p_cache_shape, if you max sequence length is varying, please set disable_p2p_cache_shape.
disable_partial_send_recv, optmize send speed for tensor parallel.
enable_delay_scale_loss, accumulate gradients util optimizer step, all gradients div by inner pipeline accumute step. instead of div accumute step on loss directly.
enable_delay_scale_loss, accumulate gradients until optimizer step, all gradients div by inner pipeline accumute step. instead of div accumute step on loss directly.
enable_dp_comm_overlap, fuse data parallel gradient communication.
enable_sharding_comm_overlap, fuse sharding stage 1 parallel gradient communication.
enable_release_grads, reduce peak memory usage by releasing gradients after each iteration. The creation of gradients will be postponed until backward propagation of the next iteration.
@@ -600,7 +603,10 @@ class TrainingArguments:
"enable_mp_async_allreduce, it supports all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear backward when it set True, which can accelerate model parallel performance. \n"
"enable_mp_skip_c_identity, it supports skip c_identity in ColumnParallelLinear and RowParallelLinear. It only works when set mp_async_allreduce is True. It can accelerate model parallel further.\n"
"enable_mp_fused_linear_param_grad_add, it supports fused_linear_param_grad_add in ColumnParallelLinear (cuda >= 11.6). It only works when mp_async_allreduce is true. It can accelerate model parallel further.\n"
"enable_delay_scale_loss, accumulate gradients util optimizer step, all gradients div by accumute step. instead of div accumute step on loss directly.\n"
"enable_delay_scale_loss, accumulate gradients until optimizer step, all gradients div by accumute step. instead of div accumute step on loss directly.\n"
"sync_param, in optimizer step, use broadcast to sync parameters those attr 'is_distributed' is False.\n"
"sync_grad, in optimizer step, use broadcast to sync gradients those attr 'is_distributed' is False.\n"
"sync_moment, in optimizer step, use broadcast to sync momentums those attr 'is_distributed' is False.\n"
)
},
)
@@ -612,7 +618,7 @@ class TrainingArguments:
"following config is support:\n"
"disable_p2p_cache_shape, if you max sequence length is varying, please set disable_p2p_cache_shape. \n"
"disable_partial_send_recv, optmize send speed for tensor parallel.\n"
"enable_delay_scale_loss, accumulate gradients util optimizer step, all gradients div by inner pipeline accumute step. instead of div accumute step on loss directly.\n"
"enable_delay_scale_loss, accumulate gradients until optimizer step, all gradients div by inner pipeline accumute step. instead of div accumute step on loss directly.\n"
"enable_dp_comm_overlap, fuse data parallel gradient communication. \n"
"enable_sharding_comm_overlap, fuse sharding stage 1 parallel gradient communication. \n"
"enable_overlap_p2p_comm, overlap p2p communication with computation. \n"
@@ -1062,10 +1068,13 @@ def __post_init__(self):
"enable_mp_skip_c_identity",
"enable_mp_fused_linear_param_grad_add",
"enable_delay_scale_loss",
"sync_param",
"sync_grad",
"sync_moment",
]:
raise ValueError(
f"Found unknown tensor parallell config {x}, "
f"accept config is enable_mp_async_allreduce, enable_mp_skip_c_identity and enable_mp_fused_linear_param_grad_add"
f"accept config is enable_mp_async_allreduce, enable_mp_skip_c_identity, enable_mp_fused_linear_param_grad_add, sync_param, sync_grad and sync_moment."
)
try:
if "enable_mp_async_allreduce" in mp_config:
@@ -1083,6 +1092,23 @@ def __post_init__(self):
warnings.warn(
"enable_mp_fused_linear_param_grad_add only works with enable_mp_async_allreduce. It will not work."
)

sync_param = "sync_param" in mp_config
sync_grad = "sync_grad" in mp_config
sync_moment = "sync_moment" in mp_config

# sync_param_name = [""] matches any parameter name.
# If sync_param, sync_grad and sync_moment are not set, the default value in Paddle is :
# sync_param = True, sync_grad = False, sync_moment = False, sync_param_name = ["embedding", "layer_norm", ".b_"].
if sync_param:
strategy.hybrid_configs["mp_configs"].sync_param = True
strategy.hybrid_configs["mp_configs"].sync_param_name = [""]
if sync_grad:
strategy.hybrid_configs["mp_configs"].sync_grad = True
strategy.hybrid_configs["mp_configs"].sync_grad_name = [""]
if sync_moment:
strategy.hybrid_configs["mp_configs"].sync_moment = True
strategy.hybrid_configs["mp_configs"].sync_moment_name = [""]
except:
warnings.warn(
"The enable_mp_async_allreduce, enable_mp_skip_c_identity and enable_mp_fused_linear_param_grad_add are not supported "
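
To summarize the change, the following is a self-contained sketch of the flag validation and mp_configs mapping that the __post_init__ hunk above adds. A plain dataclass stands in for fleet's real `strategy.hybrid_configs["mp_configs"]` object so the logic runs without a distributed setup; the space-splitting of the config string and the default name filters for grads and moments are assumptions.

```python
# Simplified, standalone rendering of the logic shown in the diff above;
# MpConfigs is a stand-in, not Paddle's real mp_configs object.
from dataclasses import dataclass, field

ACCEPTED = {
    "enable_mp_async_allreduce",
    "enable_mp_skip_c_identity",
    "enable_mp_fused_linear_param_grad_add",
    "enable_delay_scale_loss",
    "sync_param",
    "sync_grad",
    "sync_moment",
}

DEFAULT_FILTERS = ["embedding", "layer_norm", ".b_"]  # per the review reply (assumed same for grad/moment)

@dataclass
class MpConfigs:
    sync_param: bool = True
    sync_grad: bool = False
    sync_moment: bool = False
    sync_param_name: list = field(default_factory=lambda: list(DEFAULT_FILTERS))
    sync_grad_name: list = field(default_factory=lambda: list(DEFAULT_FILTERS))
    sync_moment_name: list = field(default_factory=lambda: list(DEFAULT_FILTERS))

def apply_tensor_parallel_config(config_str: str, mp: MpConfigs) -> MpConfigs:
    flags = set(config_str.split())  # space-separated flags (assumed delimiter)
    unknown = flags - ACCEPTED
    if unknown:
        raise ValueError(
            f"Found unknown tensor parallel config {unknown}, accepted configs are {sorted(ACCEPTED)}"
        )
    # Turning a sync_* switch on forces syncing of *all* tensors of that kind:
    # an empty-string name filter matches every parameter name.
    if "sync_param" in flags:
        mp.sync_param, mp.sync_param_name = True, [""]
    if "sync_grad" in flags:
        mp.sync_grad, mp.sync_grad_name = True, [""]
    if "sync_moment" in flags:
        mp.sync_moment, mp.sync_moment_name = True, [""]
    return mp

print(apply_tensor_parallel_config("sync_param sync_grad", MpConfigs()))
```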