[llm]update dpo criterion #9620
Thanks for your contribution!
@@ -148,16 +148,25 @@ def dpo_logps(
if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
    labels, sparse_tgt_idx = sequence_parallel_sparse_mask_labels(labels, 0)

-   hidden_states = paddle.take_along_axis(hidden_states, sparse_tgt_idx, axis=0)
+   hidden_states = paddle.gather(hidden_states, sparse_tgt_idx, axis=0)
Use the `gather` API instead of `take_along_axis` to improve efficiency.
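A minimal sketch of why the swap preserves results, using NumPy as a stand-in for Paddle (assumption: `np.take` and `np.take_along_axis` mirror the axis-0 behavior of `paddle.gather` and `paddle.take_along_axis` here). The label-masking step is a hypothetical simplification of `sequence_parallel_sparse_mask_labels`, and `IGNORE_INDEX = -100` is an assumed value. Both calls select the same rows. `take_along_axis` needs an index with the same rank as the input, so a `[N, 1]` index is broadcast across the hidden dimension, while `gather` takes the 1-D index directly. That difference is one plausible source of the speedup, not a measured one.

```python
import numpy as np

IGNORE_INDEX = -100  # hypothetical ignore label value


def select_target_rows(hidden_states, labels, ignore_index=IGNORE_INDEX):
    """Return the kept row indices plus the rows selected two equivalent ways."""
    # 1-D positions of tokens that carry a real label (the sparse target index).
    sparse_tgt_idx = np.nonzero(labels != ignore_index)[0]

    # gather-style lookup: one 1-D index, whole rows copied out.
    via_gather = np.take(hidden_states, sparse_tgt_idx, axis=0)

    # take_along_axis-style lookup: the index must match the input's rank,
    # so it is reshaped to [N, 1] and broadcast across the hidden dimension.
    via_along_axis = np.take_along_axis(hidden_states, sparse_tgt_idx[:, None], axis=0)

    return sparse_tgt_idx, via_gather, via_along_axis


hidden_states = np.arange(12, dtype=np.float32).reshape(6, 2)  # [tokens, hidden]
labels = np.array([-100, 5, -100, 7, 8, -100])
idx, rows_gather, rows_along = select_target_rows(hidden_states, labels)
print(idx)                                      # [1 3 4]
print(np.array_equal(rows_gather, rows_along))  # True
```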
paddlenlp/trl/dpo_criterion.py
Outdated
-   hidden_states = paddle.take_along_axis(hidden_states, sparse_tgt_idx.unsqueeze(-1), axis=0)
+   hidden_states = paddle.gather(hidden_states, sparse_tgt_idx, axis=0)
elif self.config.use_fused_head_and_loss_fn:
Add the sequence-parallel gradient computation that was previously missing.
- [(per_token_logps[response_index[1] : response_index[2]]).sum() for response_index in response_indexs],
+ [
+     (
+         paddle.gather(
Use `gather` instead of slicing so that the same code runs on both GPU and NPU.
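The diff above is cut off at `paddle.gather(`, so its exact arguments are not visible. A minimal sketch of the equivalence, assuming the gather form builds explicit positions with an `arange` over `[start, end)` (our assumption), with NumPy standing in for Paddle. Only fields 1 and 2 of each `response_index` row are read, as in the removed slice line; the sample rows are hypothetical.

```python
import numpy as np


def response_logps_slice(per_token_logps, response_indexs):
    # Slice form (the removed line): sum per-token log-probs over [start, end).
    return np.array([per_token_logps[r[1]:r[2]].sum() for r in response_indexs])


def response_logps_gather(per_token_logps, response_indexs):
    # Gather form (assumed): materialize the positions, then gather and sum.
    return np.array([
        np.take(per_token_logps, np.arange(r[1], r[2]), axis=0).sum()
        for r in response_indexs
    ])


per_token_logps = np.array([-0.5, -1.0, -0.25, -2.0, -0.75, -1.5])
response_indexs = np.array([[0, 0, 3], [0, 3, 6]])  # hypothetical rows
print(response_logps_slice(per_token_logps, response_indexs))   # [-1.75 -4.25]
print(response_logps_gather(per_token_logps, response_indexs))  # [-1.75 -4.25]
```

Both forms return the same per-response sums; the gather form routes the lookup through an indexing op rather than a slice, which the review comment gives as the reason it runs on both GPU and NPU.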
@@ -194,64 +203,65 @@ def dpo_logps(

if len(response_indexs.shape) == 3:
    response_indexs = response_indexs[0]

offset = 1 if self.ignore_eos_token else 0
The `offset` exists so that RM and DPO can share one data pipeline; for DPO, `ignore_eos_token` defaults to True.
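The diff shows how `offset` is defined but not where it is applied. A minimal sketch, assuming (our assumption, not confirmed by the diff) that it is subtracted from each response's end index so the trailing EOS token is left out of the summed log-probability. `response_spans` and `summed_logps` are hypothetical helpers, and NumPy stands in for Paddle.

```python
import numpy as np


def response_spans(response_indexs, ignore_eos_token=True):
    # Hypothetical: drop the final (EOS) position from each [start, end) span.
    offset = 1 if ignore_eos_token else 0
    return [(int(r[1]), int(r[2]) - offset) for r in response_indexs]


def summed_logps(per_token_logps, response_indexs, ignore_eos_token=True):
    # Sum per-token log-probs over each (possibly shortened) span.
    spans = response_spans(response_indexs, ignore_eos_token)
    return np.array([per_token_logps[start:end].sum() for start, end in spans])


per_token_logps = np.array([0.0, 0.0, -1.0, -1.0, -1.0, -9.0])  # last token plays the EOS role
rows = np.array([[0, 2, 6]])  # hypothetical row; fields 1 and 2 are start/end
print(summed_logps(per_token_logps, rows))                          # [-3.]
print(summed_logps(per_token_logps, rows, ignore_eos_token=False))  # [-12.]
```

With the flag on, the EOS term (-9.0 here) is excluded; turning it off keeps it, so one data layout can serve both settings.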
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## develop #9620 +/- ##
===========================================
- Coverage 53.10% 52.98% -0.13%
===========================================
Files 704 708 +4
Lines 110967 111168 +201
===========================================
- Hits 58925 58898 -27
- Misses 52042 52270 +228
LGTM
LGTM
PR types
Performance optimization
PR changes
APIs
Description
Optimize the DPO criterion.